diff options
Diffstat (limited to 'synapse')
102 files changed, 1664 insertions, 5400 deletions
diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 330de21f6b..f03fdd6dae 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -31,11 +31,6 @@ MAX_ALIAS_LENGTH = 255 # the maximum length for a user id is 255 characters MAX_USERID_LENGTH = 255 -# The maximum length for a group id is 255 characters -MAX_GROUPID_LENGTH = 255 -MAX_GROUP_CATEGORYID_LENGTH = 255 -MAX_GROUP_ROLEID_LENGTH = 255 - class Membership: @@ -142,7 +137,13 @@ class DeviceKeyAlgorithms: class EduTypes: - Presence: Final = "m.presence" + PRESENCE: Final = "m.presence" + TYPING: Final = "m.typing" + RECEIPT: Final = "m.receipt" + DEVICE_LIST_UPDATE: Final = "m.device_list_update" + SIGNING_KEY_UPDATE: Final = "m.signing_key_update" + UNSTABLE_SIGNING_KEY_UPDATE: Final = "org.matrix.signing_key_update" + DIRECT_TO_DEVICE: Final = "m.direct_to_device" class RejectedReason: diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 6650e826d5..cc7b785472 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -79,6 +79,13 @@ class Codes(str, Enum): WEAK_PASSWORD = "M_WEAK_PASSWORD" INVALID_SIGNATURE = "M_INVALID_SIGNATURE" USER_DEACTIVATED = "M_USER_DEACTIVATED" + + # The account has been suspended on the server. + # By opposition to `USER_DEACTIVATED`, this is a reversible measure + # that can possibly be appealed and reverted. + # Part of MSC3823. + USER_ACCOUNT_SUSPENDED = "ORG.MATRIX.MSC3823.USER_ACCOUNT_SUSPENDED" + BAD_ALIAS = "M_BAD_ALIAS" # For restricted join rules. UNABLE_AUTHORISE_JOIN = "M_UNABLE_TO_AUTHORISE_JOIN" @@ -139,7 +146,13 @@ class SynapseError(CodeMessageException): errcode: Matrix error code e.g 'M_FORBIDDEN' """ - def __init__(self, code: int, msg: str, errcode: str = Codes.UNKNOWN): + def __init__( + self, + code: int, + msg: str, + errcode: str = Codes.UNKNOWN, + additional_fields: Optional[Dict] = None, + ): """Constructs a synapse error. Args: @@ -149,9 +162,13 @@ class SynapseError(CodeMessageException): """ super().__init__(code, msg) self.errcode = errcode + if additional_fields is None: + self._additional_fields: Dict = {} + else: + self._additional_fields = dict(additional_fields) def error_dict(self) -> "JsonDict": - return cs_error(self.msg, self.errcode) + return cs_error(self.msg, self.errcode, **self._additional_fields) class InvalidAPICallError(SynapseError): @@ -176,14 +193,7 @@ class ProxiedRequestError(SynapseError): errcode: str = Codes.UNKNOWN, additional_fields: Optional[Dict] = None, ): - super().__init__(code, msg, errcode) - if additional_fields is None: - self._additional_fields: Dict = {} - else: - self._additional_fields = dict(additional_fields) - - def error_dict(self) -> "JsonDict": - return cs_error(self.msg, self.errcode, **self._additional_fields) + super().__init__(code, msg, errcode, additional_fields) class ConsentNotGivenError(SynapseError): diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index b91ce06de7..b007147519 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -33,7 +33,7 @@ from typing import ( import jsonschema from jsonschema import FormatChecker -from synapse.api.constants import EventContentFields +from synapse.api.constants import EduTypes, EventContentFields from synapse.api.errors import SynapseError from synapse.api.presence import UserPresenceState from synapse.events import EventBase @@ -347,7 +347,7 @@ class Filter: user_id = event.user_id field_matchers = { "senders": lambda v: user_id == v, - "types": lambda v: "m.presence" == v, + "types": lambda v: EduTypes.PRESENCE == v, } return self._check_fields(field_matchers) else: diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 2a9480a5c1..0a6dd618f6 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -69,7 +69,6 @@ from synapse.rest.admin import register_servlets_for_media_repo from synapse.rest.client import ( account_data, events, - groups, initial_sync, login, presence, @@ -78,6 +77,7 @@ from synapse.rest.client import ( read_marker, receipts, room, + room_batch, room_keys, sendtodevice, sync, @@ -87,7 +87,7 @@ from synapse.rest.client import ( voip, ) from synapse.rest.client._base import client_patterns -from synapse.rest.client.account import ThreepidRestServlet +from synapse.rest.client.account import ThreepidRestServlet, WhoamiRestServlet from synapse.rest.client.devices import DevicesRestServlet from synapse.rest.client.keys import ( KeyChangesServlet, @@ -289,6 +289,7 @@ class GenericWorkerServer(HomeServer): RegistrationTokenValidityRestServlet(self).register(resource) login.register_servlets(self, resource) ThreepidRestServlet(self).register(resource) + WhoamiRestServlet(self).register(resource) DevicesRestServlet(self).register(resource) # Read-only @@ -308,6 +309,7 @@ class GenericWorkerServer(HomeServer): room.register_servlets(self, resource, is_worker=True) room.register_deprecated_servlets(self, resource) initial_sync.register_servlets(self, resource) + room_batch.register_servlets(self, resource) room_keys.register_servlets(self, resource) tags.register_servlets(self, resource) account_data.register_servlets(self, resource) @@ -320,9 +322,6 @@ class GenericWorkerServer(HomeServer): presence.register_servlets(self, resource) - if self.config.experimental.groups_enabled: - groups.register_servlets(self, resource) - resources.update({CLIENT_API_PREFIX: resource}) resources.update(build_synapse_client_resource_tree(self)) diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index a610fb785d..ed92c2e910 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -23,13 +23,7 @@ from netaddr import IPSet from synapse.api.constants import EventTypes from synapse.events import EventBase -from synapse.types import ( - DeviceListUpdates, - GroupID, - JsonDict, - UserID, - get_domain_from_id, -) +from synapse.types import DeviceListUpdates, JsonDict, UserID from synapse.util.caches.descriptors import _CacheContext, cached if TYPE_CHECKING: @@ -55,7 +49,6 @@ class ApplicationServiceState(Enum): @attr.s(slots=True, frozen=True, auto_attribs=True) class Namespace: exclusive: bool - group_id: Optional[str] regex: Pattern[str] @@ -141,30 +134,13 @@ class ApplicationService: exclusive = regex_obj.get("exclusive") if not isinstance(exclusive, bool): raise ValueError("Expected bool for 'exclusive' in ns '%s'" % ns) - group_id = regex_obj.get("group_id") - if group_id: - if not isinstance(group_id, str): - raise ValueError( - "Expected string for 'group_id' in ns '%s'" % ns - ) - try: - GroupID.from_string(group_id) - except Exception: - raise ValueError( - "Expected valid group ID for 'group_id' in ns '%s'" % ns - ) - - if get_domain_from_id(group_id) != self.server_name: - raise ValueError( - "Expected 'group_id' to be this host in ns '%s'" % ns - ) regex = regex_obj.get("regex") if not isinstance(regex, str): raise ValueError("Expected string for 'regex' in ns '%s'" % ns) # Pre-compile regex. - result[ns].append(Namespace(exclusive, group_id, re.compile(regex))) + result[ns].append(Namespace(exclusive, re.compile(regex))) return result @@ -369,21 +345,6 @@ class ApplicationService: if namespace.exclusive ] - def get_groups_for_user(self, user_id: str) -> Iterable[str]: - """Get the groups that this user is associated with by this AS - - Args: - user_id: The ID of the user. - - Returns: - An iterable that yields group_id strings. - """ - return ( - namespace.group_id - for namespace in self.namespaces[ApplicationService.NS_USERS] - if namespace.group_id and namespace.regex.match(user_id) - ) - def is_rate_limited(self) -> bool: return self.rate_limited diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index d19f8dd996..df1c214462 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -14,7 +14,7 @@ # limitations under the License. import logging import urllib.parse -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple from prometheus_client import Counter from typing_extensions import TypeGuard @@ -155,6 +155,9 @@ class ApplicationServiceApi(SimpleHttpClient): if service.url is None: return [] + # This is required by the configuration. + assert service.hs_token is not None + uri = "%s%s/thirdparty/%s/%s" % ( service.url, APP_SERVICE_PREFIX, @@ -162,7 +165,11 @@ class ApplicationServiceApi(SimpleHttpClient): urllib.parse.quote(protocol), ) try: - response = await self.get_json(uri, fields) + args: Mapping[Any, Any] = { + **fields, + b"access_token": service.hs_token, + } + response = await self.get_json(uri, args=args) if not isinstance(response, list): logger.warning( "query_3pe to %s returned an invalid response %r", uri, response @@ -190,13 +197,15 @@ class ApplicationServiceApi(SimpleHttpClient): return {} async def _get() -> Optional[JsonDict]: + # This is required by the configuration. + assert service.hs_token is not None uri = "%s%s/thirdparty/protocol/%s" % ( service.url, APP_SERVICE_PREFIX, urllib.parse.quote(protocol), ) try: - info = await self.get_json(uri) + info = await self.get_json(uri, {"access_token": service.hs_token}) if not _is_valid_3pe_metadata(info): logger.warning( diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 3b49e60716..de5e5216c2 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -384,6 +384,11 @@ class _TransactionController: device_list_summary: The device list summary to include in the transaction. """ try: + service_is_up = await self._is_service_up(service) + # Don't create empty txns when in recovery mode (ephemeral events are dropped) + if not service_is_up and not events: + return + txn = await self.store.create_appservice_txn( service=service, events=events, @@ -393,7 +398,6 @@ class _TransactionController: unused_fallback_keys=unused_fallback_keys or {}, device_list_summary=device_list_summary or DeviceListUpdates(), ) - service_is_up = await self._is_service_up(service) if service_is_up: sent = await txn.send(self.as_api) if sent: diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index 71d6655fda..01ea2b4dab 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -32,7 +32,6 @@ from synapse.config import ( emailconfig, experimental, federation, - groups, jwt, key, logger, @@ -107,7 +106,6 @@ class RootConfig: push: push.PushConfig spamchecker: spam_checker.SpamCheckerConfig room: room.RoomConfig - groups: groups.GroupsConfig userdirectory: user_directory.UserDirectoryConfig consent: consent.ConsentConfig stats: stats.StatsConfig diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index b20d949689..f2dfd49b07 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -73,9 +73,6 @@ class ExperimentalConfig(Config): # MSC3720 (Account status endpoint) self.msc3720_enabled: bool = experimental.get("msc3720_enabled", False) - # The deprecated groups feature. - self.groups_enabled: bool = experimental.get("groups_enabled", False) - # MSC2654: Unread counts self.msc2654_enabled: bool = experimental.get("msc2654_enabled", False) @@ -84,3 +81,6 @@ class ExperimentalConfig(Config): # MSC3786 (Add a default push rule to ignore m.room.server_acl events) self.msc3786_enabled: bool = experimental.get("msc3786_enabled", False) + + # MSC3772: A push rule for mutual relations. + self.msc3772_enabled: bool = experimental.get("msc3772_enabled", False) diff --git a/synapse/config/groups.py b/synapse/config/groups.py deleted file mode 100644 index c9b9c6daad..0000000000 --- a/synapse/config/groups.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright 2017 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any - -from synapse.types import JsonDict - -from ._base import Config - - -class GroupsConfig(Config): - section = "groups" - - def read_config(self, config: JsonDict, **kwargs: Any) -> None: - self.enable_group_creation = config.get("enable_group_creation", False) - self.group_creation_prefix = config.get("group_creation_prefix", "") - - def generate_config_section(self, **kwargs: Any) -> str: - return """\ - # Uncomment to allow non-server-admin users to create groups on this server - # - #enable_group_creation: true - - # If enabled, non server admins can only create groups with local parts - # starting with this prefix - # - #group_creation_prefix: "unofficial_" - """ diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index a4ec706908..4d2b298a70 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -25,7 +25,6 @@ from .database import DatabaseConfig from .emailconfig import EmailConfig from .experimental import ExperimentalConfig from .federation import FederationConfig -from .groups import GroupsConfig from .jwt import JWTConfig from .key import KeyConfig from .logger import LoggingConfig @@ -89,7 +88,6 @@ class HomeServerConfig(RootConfig): PushConfig, SpamCheckerConfig, RoomConfig, - GroupsConfig, UserDirectoryConfig, ConsentConfig, StatsConfig, diff --git a/synapse/config/server.py b/synapse/config/server.py index f73d5e1f66..657322cb1f 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -679,6 +679,17 @@ class ServerConfig(Config): config.get("exclude_rooms_from_sync") or [] ) + delete_stale_devices_after: Optional[str] = ( + config.get("delete_stale_devices_after") or None + ) + + if delete_stale_devices_after is not None: + self.delete_stale_devices_after: Optional[int] = self.parse_duration( + delete_stale_devices_after + ) + else: + self.delete_stale_devices_after = None + def has_tls_listener(self) -> bool: return any(listener.tls for listener in self.listeners) diff --git a/synapse/config/tracer.py b/synapse/config/tracer.py index 3472a9a01b..ae68a3dd1a 100644 --- a/synapse/config/tracer.py +++ b/synapse/config/tracer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Set +from typing import Any, List, Set from synapse.types import JsonDict from synapse.util.check_dependencies import DependencyException, check_requirements @@ -49,7 +49,9 @@ class TracerConfig(Config): # The tracer is enabled so sanitize the config - self.opentracer_whitelist = opentracing_config.get("homeserver_whitelist", []) + self.opentracer_whitelist: List[str] = opentracing_config.get( + "homeserver_whitelist", [] + ) if not isinstance(self.opentracer_whitelist, list): raise ConfigError("Tracer homeserver_whitelist config is malformed") diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 7a91544119..b700cbbfa1 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -22,7 +22,7 @@ from synapse.events import EventBase from synapse.types import JsonDict, StateMap if TYPE_CHECKING: - from synapse.storage import Storage + from synapse.storage.controllers import StorageControllers from synapse.storage.databases.main import DataStore from synapse.storage.state import StateFilter @@ -84,7 +84,7 @@ class EventContext: incomplete state. """ - _storage: "Storage" + _storage: "StorageControllers" rejected: Union[Literal[False], str] = False _state_group: Optional[int] = None state_group_before_event: Optional[int] = None @@ -97,7 +97,7 @@ class EventContext: @staticmethod def with_state( - storage: "Storage", + storage: "StorageControllers", state_group: Optional[int], state_group_before_event: Optional[int], state_delta_due_to_event: Optional[StateMap[str]], @@ -117,7 +117,7 @@ class EventContext: @staticmethod def for_outlier( - storage: "Storage", + storage: "StorageControllers", ) -> "EventContext": """Return an EventContext instance suitable for persisting an outlier event""" return EventContext(storage=storage) @@ -147,7 +147,7 @@ class EventContext: } @staticmethod - def deserialize(storage: "Storage", input: JsonDict) -> "EventContext": + def deserialize(storage: "StorageControllers", input: JsonDict) -> "EventContext": """Converts a dict that was produced by `serialize` back into a EventContext. diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 1048b4c825..d2e06c754e 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -21,6 +21,7 @@ from typing import ( Awaitable, Callable, Collection, + Dict, List, Optional, Tuple, @@ -41,12 +42,17 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) - CHECK_EVENT_FOR_SPAM_CALLBACK = Callable[ ["synapse.events.EventBase"], Awaitable[ Union[ str, + Codes, + # Highly experimental, not officially part of the spamchecker API, may + # disappear without warning depending on the results of ongoing + # experiments. + # Use this to return additional information as part of an error. + Tuple[Codes, Dict], # Deprecated bool, ] @@ -267,7 +273,9 @@ class SpamChecker: if check_media_file_for_spam is not None: self._check_media_file_for_spam_callbacks.append(check_media_file_for_spam) - async def check_event_for_spam(self, event: "synapse.events.EventBase") -> str: + async def check_event_for_spam( + self, event: "synapse.events.EventBase" + ) -> Union[Tuple[Codes, Dict], str]: """Checks if a given event is considered "spammy" by this server. If the server considers an event spammy, then it will be rejected if @@ -303,7 +311,7 @@ class SpamChecker: # mypy complains that we can't reach this code because of the # return type in CHECK_EVENT_FOR_SPAM_CALLBACK, but we don't know # for sure that the module actually returns it. - logger.warning( # type: ignore[unreachable] + logger.warning( "Module returned invalid value, rejecting message as spam" ) res = "This message has been rejected as probable spam" diff --git a/synapse/events/validator.py b/synapse/events/validator.py index 360d24274a..29fa9b3880 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import collections.abc -from typing import Iterable, Type, Union +from typing import Iterable, Type, Union, cast import jsonschema @@ -103,7 +103,12 @@ class EventValidator: except jsonschema.ValidationError as e: if e.path: # example: "users_default": '0' is not of type 'integer' - message = '"' + e.path[-1] + '": ' + e.message # noqa: B306 + # cast safety: path entries can be integers, if we fail to validate + # items in an array. However the POWER_LEVELS_SCHEMA doesn't expect + # to see any arrays. + message = ( + '"' + cast(str, e.path[-1]) + '": ' + e.message # noqa: B306 + ) # jsonschema.ValidationError.message is a valid attribute else: # example: '0' is not of type 'integer' diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index b8232e5257..3ecede22d9 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -109,7 +109,6 @@ class FederationServer(FederationBase): super().__init__(hs) self.handler = hs.get_federation_handler() - self.storage = hs.get_storage() self._spam_checker = hs.get_spam_checker() self._federation_event_handler = hs.get_federation_event_handler() self.state = hs.get_state_handler() @@ -1353,7 +1352,7 @@ class FederationHandlerRegistry: self._edu_type_to_instance[edu_type] = instance_names async def on_edu(self, edu_type: str, origin: str, content: dict) -> None: - if not self.config.server.use_presence and edu_type == EduTypes.Presence: + if not self.config.server.use_presence and edu_type == EduTypes.PRESENCE: return # Check if we have a handler on this instance diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index d80f0ac5e8..333ca9a97f 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -21,6 +21,7 @@ from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tupl import attr from prometheus_client import Counter +from synapse.api.constants import EduTypes from synapse.api.errors import ( FederationDeniedError, HttpResponseException, @@ -223,7 +224,7 @@ class PerDestinationQueue: """Marks that the destination has new data to send, without starting a new transaction. - If a transaction loop is already in progress then a new transcation will + If a transaction loop is already in progress then a new transaction will be attempted when the current one finishes. """ @@ -542,7 +543,7 @@ class PerDestinationQueue: edu = Edu( origin=self._server_name, destination=self._destination, - edu_type="m.receipt", + edu_type=EduTypes.RECEIPT, content=self._pending_rrs, ) self._pending_rrs = {} @@ -592,7 +593,7 @@ class PerDestinationQueue: Edu( origin=self._server_name, destination=self._destination, - edu_type="m.direct_to_device", + edu_type=EduTypes.DIRECT_TO_DEVICE, content=content, ) for content in contents @@ -670,7 +671,7 @@ class _TransactionQueueManager: Edu( origin=self.queue._server_name, destination=self.queue._destination, - edu_type="m.presence", + edu_type=EduTypes.PRESENCE, content={ "push": [ format_user_presence_state( diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index 0c1cad86ab..75081810fd 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, List from prometheus_client import Gauge +from synapse.api.constants import EduTypes from synapse.api.errors import HttpResponseException from synapse.events import EventBase from synapse.federation.persistence import TransactionActions @@ -126,7 +127,10 @@ class TransactionManager: len(edus), ) if issue_8631_logger.isEnabledFor(logging.DEBUG): - DEVICE_UPDATE_EDUS = {"m.device_list_update", "m.signing_key_update"} + DEVICE_UPDATE_EDUS = { + EduTypes.DEVICE_LIST_UPDATE, + EduTypes.SIGNING_KEY_UPDATE, + } device_list_updates = [ edu.content for edu in edus if edu.edu_type in DEVICE_UPDATE_EDUS ] diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 2686ee2e51..9e84bd677e 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -17,7 +17,6 @@ import logging import urllib from typing import ( Any, - Awaitable, Callable, Collection, Dict, @@ -49,11 +48,6 @@ from synapse.types import JsonDict logger = logging.getLogger(__name__) -# Send join responses can be huge, so we set a separate limit here. The response -# is parsed in a streaming manner, which helps alleviate the issue of memory -# usage a bit. -MAX_RESPONSE_SIZE_SEND_JOIN = 500 * 1024 * 1024 - class TransportLayerClient: """Sends federation HTTP requests to other servers""" @@ -349,7 +343,6 @@ class TransportLayerClient: path=path, data=content, parser=SendJoinParser(room_version, v1_api=True), - max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN, ) async def send_join_v2( @@ -372,7 +365,6 @@ class TransportLayerClient: args=query_params, data=content, parser=SendJoinParser(room_version, v1_api=False), - max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN, ) async def send_leave_v1( @@ -688,488 +680,6 @@ class TransportLayerClient: timeout=timeout, ) - async def get_group_profile( - self, destination: str, group_id: str, requester_user_id: str - ) -> JsonDict: - """Get a group profile""" - path = _create_v1_path("/groups/%s/profile", group_id) - - return await self.client.get_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - ignore_backoff=True, - ) - - async def update_group_profile( - self, destination: str, group_id: str, requester_user_id: str, content: JsonDict - ) -> JsonDict: - """Update a remote group profile - - Args: - destination - group_id - requester_user_id - content: The new profile of the group - """ - path = _create_v1_path("/groups/%s/profile", group_id) - - return self.client.post_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - data=content, - ignore_backoff=True, - ) - - async def get_group_summary( - self, destination: str, group_id: str, requester_user_id: str - ) -> JsonDict: - """Get a group summary""" - path = _create_v1_path("/groups/%s/summary", group_id) - - return await self.client.get_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - ignore_backoff=True, - ) - - async def get_rooms_in_group( - self, destination: str, group_id: str, requester_user_id: str - ) -> JsonDict: - """Get all rooms in a group""" - path = _create_v1_path("/groups/%s/rooms", group_id) - - return await self.client.get_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - ignore_backoff=True, - ) - - async def add_room_to_group( - self, - destination: str, - group_id: str, - requester_user_id: str, - room_id: str, - content: JsonDict, - ) -> JsonDict: - """Add a room to a group""" - path = _create_v1_path("/groups/%s/room/%s", group_id, room_id) - - return await self.client.post_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - data=content, - ignore_backoff=True, - ) - - async def update_room_in_group( - self, - destination: str, - group_id: str, - requester_user_id: str, - room_id: str, - config_key: str, - content: JsonDict, - ) -> JsonDict: - """Update room in group""" - path = _create_v1_path( - "/groups/%s/room/%s/config/%s", group_id, room_id, config_key - ) - - return await self.client.post_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - data=content, - ignore_backoff=True, - ) - - async def remove_room_from_group( - self, destination: str, group_id: str, requester_user_id: str, room_id: str - ) -> JsonDict: - """Remove a room from a group""" - path = _create_v1_path("/groups/%s/room/%s", group_id, room_id) - - return await self.client.delete_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - ignore_backoff=True, - ) - - async def get_users_in_group( - self, destination: str, group_id: str, requester_user_id: str - ) -> JsonDict: - """Get users in a group""" - path = _create_v1_path("/groups/%s/users", group_id) - - return await self.client.get_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - ignore_backoff=True, - ) - - async def get_invited_users_in_group( - self, destination: str, group_id: str, requester_user_id: str - ) -> JsonDict: - """Get users that have been invited to a group""" - path = _create_v1_path("/groups/%s/invited_users", group_id) - - return await self.client.get_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - ignore_backoff=True, - ) - - async def accept_group_invite( - self, destination: str, group_id: str, user_id: str, content: JsonDict - ) -> JsonDict: - """Accept a group invite""" - path = _create_v1_path("/groups/%s/users/%s/accept_invite", group_id, user_id) - - return await self.client.post_json( - destination=destination, path=path, data=content, ignore_backoff=True - ) - - def join_group( - self, destination: str, group_id: str, user_id: str, content: JsonDict - ) -> Awaitable[JsonDict]: - """Attempts to join a group""" - path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id) - - return self.client.post_json( - destination=destination, path=path, data=content, ignore_backoff=True - ) - - async def invite_to_group( - self, - destination: str, - group_id: str, - user_id: str, - requester_user_id: str, - content: JsonDict, - ) -> JsonDict: - """Invite a user to a group""" - path = _create_v1_path("/groups/%s/users/%s/invite", group_id, user_id) - - return await self.client.post_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - data=content, - ignore_backoff=True, - ) - - async def invite_to_group_notification( - self, destination: str, group_id: str, user_id: str, content: JsonDict - ) -> JsonDict: - """Sent by group server to inform a user's server that they have been - invited. - """ - - path = _create_v1_path("/groups/local/%s/users/%s/invite", group_id, user_id) - - return await self.client.post_json( - destination=destination, path=path, data=content, ignore_backoff=True - ) - - async def remove_user_from_group( - self, - destination: str, - group_id: str, - requester_user_id: str, - user_id: str, - content: JsonDict, - ) -> JsonDict: - """Remove a user from a group""" - path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id) - - return await self.client.post_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - data=content, - ignore_backoff=True, - ) - - async def remove_user_from_group_notification( - self, destination: str, group_id: str, user_id: str, content: JsonDict - ) -> JsonDict: - """Sent by group server to inform a user's server that they have been - kicked from the group. - """ - - path = _create_v1_path("/groups/local/%s/users/%s/remove", group_id, user_id) - - return await self.client.post_json( - destination=destination, path=path, data=content, ignore_backoff=True - ) - - async def renew_group_attestation( - self, destination: str, group_id: str, user_id: str, content: JsonDict - ) -> JsonDict: - """Sent by either a group server or a user's server to periodically update - the attestations - """ - - path = _create_v1_path("/groups/%s/renew_attestation/%s", group_id, user_id) - - return await self.client.post_json( - destination=destination, path=path, data=content, ignore_backoff=True - ) - - async def update_group_summary_room( - self, - destination: str, - group_id: str, - user_id: str, - room_id: str, - category_id: str, - content: JsonDict, - ) -> JsonDict: - """Update a room entry in a group summary""" - if category_id: - path = _create_v1_path( - "/groups/%s/summary/categories/%s/rooms/%s", - group_id, - category_id, - room_id, - ) - else: - path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id) - - return await self.client.post_json( - destination=destination, - path=path, - args={"requester_user_id": user_id}, - data=content, - ignore_backoff=True, - ) - - async def delete_group_summary_room( - self, - destination: str, - group_id: str, - user_id: str, - room_id: str, - category_id: str, - ) -> JsonDict: - """Delete a room entry in a group summary""" - if category_id: - path = _create_v1_path( - "/groups/%s/summary/categories/%s/rooms/%s", - group_id, - category_id, - room_id, - ) - else: - path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id) - - return await self.client.delete_json( - destination=destination, - path=path, - args={"requester_user_id": user_id}, - ignore_backoff=True, - ) - - async def get_group_categories( - self, destination: str, group_id: str, requester_user_id: str - ) -> JsonDict: - """Get all categories in a group""" - path = _create_v1_path("/groups/%s/categories", group_id) - - return await self.client.get_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - ignore_backoff=True, - ) - - async def get_group_category( - self, destination: str, group_id: str, requester_user_id: str, category_id: str - ) -> JsonDict: - """Get category info in a group""" - path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id) - - return await self.client.get_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - ignore_backoff=True, - ) - - async def update_group_category( - self, - destination: str, - group_id: str, - requester_user_id: str, - category_id: str, - content: JsonDict, - ) -> JsonDict: - """Update a category in a group""" - path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id) - - return await self.client.post_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - data=content, - ignore_backoff=True, - ) - - async def delete_group_category( - self, destination: str, group_id: str, requester_user_id: str, category_id: str - ) -> JsonDict: - """Delete a category in a group""" - path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id) - - return await self.client.delete_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - ignore_backoff=True, - ) - - async def get_group_roles( - self, destination: str, group_id: str, requester_user_id: str - ) -> JsonDict: - """Get all roles in a group""" - path = _create_v1_path("/groups/%s/roles", group_id) - - return await self.client.get_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - ignore_backoff=True, - ) - - async def get_group_role( - self, destination: str, group_id: str, requester_user_id: str, role_id: str - ) -> JsonDict: - """Get a roles info""" - path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id) - - return await self.client.get_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - ignore_backoff=True, - ) - - async def update_group_role( - self, - destination: str, - group_id: str, - requester_user_id: str, - role_id: str, - content: JsonDict, - ) -> JsonDict: - """Update a role in a group""" - path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id) - - return await self.client.post_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - data=content, - ignore_backoff=True, - ) - - async def delete_group_role( - self, destination: str, group_id: str, requester_user_id: str, role_id: str - ) -> JsonDict: - """Delete a role in a group""" - path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id) - - return await self.client.delete_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - ignore_backoff=True, - ) - - async def update_group_summary_user( - self, - destination: str, - group_id: str, - requester_user_id: str, - user_id: str, - role_id: str, - content: JsonDict, - ) -> JsonDict: - """Update a users entry in a group""" - if role_id: - path = _create_v1_path( - "/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id - ) - else: - path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id) - - return await self.client.post_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - data=content, - ignore_backoff=True, - ) - - async def set_group_join_policy( - self, destination: str, group_id: str, requester_user_id: str, content: JsonDict - ) -> JsonDict: - """Sets the join policy for a group""" - path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id) - - return await self.client.put_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - data=content, - ignore_backoff=True, - ) - - async def delete_group_summary_user( - self, - destination: str, - group_id: str, - requester_user_id: str, - user_id: str, - role_id: str, - ) -> JsonDict: - """Delete a users entry in a group""" - if role_id: - path = _create_v1_path( - "/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id - ) - else: - path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id) - - return await self.client.delete_json( - destination=destination, - path=path, - args={"requester_user_id": requester_user_id}, - ignore_backoff=True, - ) - - async def bulk_get_publicised_groups( - self, destination: str, user_ids: Iterable[str] - ) -> JsonDict: - """Get the groups a list of users are publicising""" - - path = _create_v1_path("/get_groups_publicised") - - content = {"user_ids": user_ids} - - return await self.client.post_json( - destination=destination, path=path, data=content, ignore_backoff=True - ) - async def get_room_complexity(self, destination: str, room_id: str) -> JsonDict: """ Args: @@ -1360,6 +870,11 @@ class SendJoinParser(ByteParser[SendJoinResponse]): CONTENT_TYPE = "application/json" + # /send_join responses can be huge, so we override the size limit here. The response + # is parsed in a streaming manner, which helps alleviate the issue of memory + # usage a bit. + MAX_RESPONSE_SIZE = 500 * 1024 * 1024 + def __init__(self, room_version: RoomVersion, v1_api: bool): self._response = SendJoinResponse([], [], event_dict={}) self._room_version = room_version @@ -1430,6 +945,9 @@ class _StateParser(ByteParser[StateRequestResponse]): CONTENT_TYPE = "application/json" + # As with /send_join, /state responses can be huge. + MAX_RESPONSE_SIZE = 500 * 1024 * 1024 + def __init__(self, room_version: RoomVersion): self._response = StateRequestResponse([], []) self._room_version = room_version diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py index 71b2f90eb9..50623cd385 100644 --- a/synapse/federation/transport/server/__init__.py +++ b/synapse/federation/transport/server/__init__.py @@ -27,10 +27,6 @@ from synapse.federation.transport.server.federation import ( FederationAccountStatusServlet, FederationTimestampLookupServlet, ) -from synapse.federation.transport.server.groups_local import GROUP_LOCAL_SERVLET_CLASSES -from synapse.federation.transport.server.groups_server import ( - GROUP_SERVER_SERVLET_CLASSES, -) from synapse.http.server import HttpServer, JsonResource from synapse.http.servlet import ( parse_boolean_from_args, @@ -199,38 +195,6 @@ class PublicRoomList(BaseFederationServlet): return 200, data -class FederationGroupsRenewAttestaionServlet(BaseFederationServlet): - """A group or user's server renews their attestation""" - - PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)" - - def __init__( - self, - hs: "HomeServer", - authenticator: Authenticator, - ratelimiter: FederationRateLimiter, - server_name: str, - ): - super().__init__(hs, authenticator, ratelimiter, server_name) - self.handler = hs.get_groups_attestation_renewer() - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - user_id: str, - ) -> Tuple[int, JsonDict]: - # We don't need to check auth here as we check the attestation signatures - - new_content = await self.handler.on_renew_attestation( - group_id, user_id, content - ) - - return 200, new_content - - class OpenIdUserInfo(BaseFederationServlet): """ Exchange a bearer token for information about a user. @@ -292,16 +256,9 @@ class OpenIdUserInfo(BaseFederationServlet): SERVLET_GROUPS: Dict[str, Iterable[Type[BaseFederationServlet]]] = { "federation": FEDERATION_SERVLET_CLASSES, "room_list": (PublicRoomList,), - "group_server": GROUP_SERVER_SERVLET_CLASSES, - "group_local": GROUP_LOCAL_SERVLET_CLASSES, - "group_attestation": (FederationGroupsRenewAttestaionServlet,), "openid": (OpenIdUserInfo,), } -DEFAULT_SERVLET_GROUPS = ("federation", "room_list", "openid") - -GROUP_SERVLET_GROUPS = ("group_server", "group_local", "group_attestation") - def register_servlets( hs: "HomeServer", @@ -324,10 +281,7 @@ def register_servlets( Defaults to ``DEFAULT_SERVLET_GROUPS``. """ if not servlet_groups: - servlet_groups = DEFAULT_SERVLET_GROUPS - # Only allow the groups servlets if the deprecated groups feature is enabled. - if hs.config.experimental.groups_enabled: - servlet_groups = servlet_groups + GROUP_SERVLET_GROUPS + servlet_groups = SERVLET_GROUPS.keys() for servlet_group in servlet_groups: # Skip unknown servlet groups. diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py index 6fbc7b5f15..7dfb890661 100644 --- a/synapse/federation/transport/server/federation.py +++ b/synapse/federation/transport/server/federation.py @@ -27,6 +27,7 @@ from typing import ( from matrix_common.versionstring import get_distribution_version_string from typing_extensions import Literal +from synapse.api.constants import EduTypes from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersions from synapse.api.urls import FEDERATION_UNSTABLE_PREFIX, FEDERATION_V2_PREFIX @@ -108,7 +109,10 @@ class FederationSendServlet(BaseFederationServerServlet): ) if issue_8631_logger.isEnabledFor(logging.DEBUG): - DEVICE_UPDATE_EDUS = ["m.device_list_update", "m.signing_key_update"] + DEVICE_UPDATE_EDUS = [ + EduTypes.DEVICE_LIST_UPDATE, + EduTypes.SIGNING_KEY_UPDATE, + ] device_list_updates = [ edu.get("content", {}) for edu in transaction_data.get("edus", []) @@ -650,10 +654,6 @@ class FederationRoomHierarchyServlet(BaseFederationServlet): ) -class FederationRoomHierarchyUnstableServlet(FederationRoomHierarchyServlet): - PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc2946" - - class RoomComplexityServlet(BaseFederationServlet): """ Indicates to other servers how complex (and therefore likely @@ -752,7 +752,6 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( FederationVersionServlet, RoomComplexityServlet, FederationRoomHierarchyServlet, - FederationRoomHierarchyUnstableServlet, FederationV1SendKnockServlet, FederationMakeKnockServlet, FederationAccountStatusServlet, diff --git a/synapse/federation/transport/server/groups_local.py b/synapse/federation/transport/server/groups_local.py deleted file mode 100644 index 496472e1dc..0000000000 --- a/synapse/federation/transport/server/groups_local.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import TYPE_CHECKING, Dict, List, Tuple, Type - -from synapse.api.errors import SynapseError -from synapse.federation.transport.server._base import ( - Authenticator, - BaseFederationServlet, -) -from synapse.handlers.groups_local import GroupsLocalHandler -from synapse.types import JsonDict, get_domain_from_id -from synapse.util.ratelimitutils import FederationRateLimiter - -if TYPE_CHECKING: - from synapse.server import HomeServer - - -class BaseGroupsLocalServlet(BaseFederationServlet): - """Abstract base class for federation servlet classes which provides a groups local handler. - - See BaseFederationServlet for more information. - """ - - def __init__( - self, - hs: "HomeServer", - authenticator: Authenticator, - ratelimiter: FederationRateLimiter, - server_name: str, - ): - super().__init__(hs, authenticator, ratelimiter, server_name) - self.handler = hs.get_groups_local_handler() - - -class FederationGroupsLocalInviteServlet(BaseGroupsLocalServlet): - """A group server has invited a local user""" - - PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite" - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - user_id: str, - ) -> Tuple[int, JsonDict]: - if get_domain_from_id(group_id) != origin: - raise SynapseError(403, "group_id doesn't match origin") - - assert isinstance( - self.handler, GroupsLocalHandler - ), "Workers cannot handle group invites." - - new_content = await self.handler.on_invite(group_id, user_id, content) - - return 200, new_content - - -class FederationGroupsRemoveLocalUserServlet(BaseGroupsLocalServlet): - """A group server has removed a local user""" - - PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove" - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - user_id: str, - ) -> Tuple[int, None]: - if get_domain_from_id(group_id) != origin: - raise SynapseError(403, "user_id doesn't match origin") - - assert isinstance( - self.handler, GroupsLocalHandler - ), "Workers cannot handle group removals." - - await self.handler.user_removed_from_group(group_id, user_id, content) - - return 200, None - - -class FederationGroupsBulkPublicisedServlet(BaseGroupsLocalServlet): - """Get roles in a group""" - - PATH = "/get_groups_publicised" - - async def on_POST( - self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] - ) -> Tuple[int, JsonDict]: - resp = await self.handler.bulk_get_publicised_groups( - content["user_ids"], proxy=False - ) - - return 200, resp - - -GROUP_LOCAL_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( - FederationGroupsLocalInviteServlet, - FederationGroupsRemoveLocalUserServlet, - FederationGroupsBulkPublicisedServlet, -) diff --git a/synapse/federation/transport/server/groups_server.py b/synapse/federation/transport/server/groups_server.py deleted file mode 100644 index 851b50152e..0000000000 --- a/synapse/federation/transport/server/groups_server.py +++ /dev/null @@ -1,755 +0,0 @@ -# Copyright 2021 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import TYPE_CHECKING, Dict, List, Tuple, Type - -from typing_extensions import Literal - -from synapse.api.constants import MAX_GROUP_CATEGORYID_LENGTH, MAX_GROUP_ROLEID_LENGTH -from synapse.api.errors import Codes, SynapseError -from synapse.federation.transport.server._base import ( - Authenticator, - BaseFederationServlet, -) -from synapse.http.servlet import parse_string_from_args -from synapse.types import JsonDict, get_domain_from_id -from synapse.util.ratelimitutils import FederationRateLimiter - -if TYPE_CHECKING: - from synapse.server import HomeServer - - -class BaseGroupsServerServlet(BaseFederationServlet): - """Abstract base class for federation servlet classes which provides a groups server handler. - - See BaseFederationServlet for more information. - """ - - def __init__( - self, - hs: "HomeServer", - authenticator: Authenticator, - ratelimiter: FederationRateLimiter, - server_name: str, - ): - super().__init__(hs, authenticator, ratelimiter, server_name) - self.handler = hs.get_groups_server_handler() - - -class FederationGroupsProfileServlet(BaseGroupsServerServlet): - """Get/set the basic profile of a group on behalf of a user""" - - PATH = "/groups/(?P<group_id>[^/]*)/profile" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.get_group_profile(group_id, requester_user_id) - - return 200, new_content - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.update_group_profile( - group_id, requester_user_id, content - ) - - return 200, new_content - - -class FederationGroupsSummaryServlet(BaseGroupsServerServlet): - PATH = "/groups/(?P<group_id>[^/]*)/summary" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.get_group_summary(group_id, requester_user_id) - - return 200, new_content - - -class FederationGroupsRoomsServlet(BaseGroupsServerServlet): - """Get the rooms in a group on behalf of a user""" - - PATH = "/groups/(?P<group_id>[^/]*)/rooms" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.get_rooms_in_group(group_id, requester_user_id) - - return 200, new_content - - -class FederationGroupsAddRoomsServlet(BaseGroupsServerServlet): - """Add/remove room from group""" - - PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)" - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - room_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.add_room_to_group( - group_id, requester_user_id, room_id, content - ) - - return 200, new_content - - async def on_DELETE( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - room_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.remove_room_from_group( - group_id, requester_user_id, room_id - ) - - return 200, new_content - - -class FederationGroupsAddRoomsConfigServlet(BaseGroupsServerServlet): - """Update room config in group""" - - PATH = ( - "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)" - "/config/(?P<config_key>[^/]*)" - ) - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - room_id: str, - config_key: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - result = await self.handler.update_room_in_group( - group_id, requester_user_id, room_id, config_key, content - ) - - return 200, result - - -class FederationGroupsUsersServlet(BaseGroupsServerServlet): - """Get the users in a group on behalf of a user""" - - PATH = "/groups/(?P<group_id>[^/]*)/users" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.get_users_in_group(group_id, requester_user_id) - - return 200, new_content - - -class FederationGroupsInvitedUsersServlet(BaseGroupsServerServlet): - """Get the users that have been invited to a group""" - - PATH = "/groups/(?P<group_id>[^/]*)/invited_users" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.get_invited_users_in_group( - group_id, requester_user_id - ) - - return 200, new_content - - -class FederationGroupsInviteServlet(BaseGroupsServerServlet): - """Ask a group server to invite someone to the group""" - - PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite" - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - user_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.invite_to_group( - group_id, user_id, requester_user_id, content - ) - - return 200, new_content - - -class FederationGroupsAcceptInviteServlet(BaseGroupsServerServlet): - """Accept an invitation from the group server""" - - PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite" - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - user_id: str, - ) -> Tuple[int, JsonDict]: - if get_domain_from_id(user_id) != origin: - raise SynapseError(403, "user_id doesn't match origin") - - new_content = await self.handler.accept_invite(group_id, user_id, content) - - return 200, new_content - - -class FederationGroupsJoinServlet(BaseGroupsServerServlet): - """Attempt to join a group""" - - PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join" - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - user_id: str, - ) -> Tuple[int, JsonDict]: - if get_domain_from_id(user_id) != origin: - raise SynapseError(403, "user_id doesn't match origin") - - new_content = await self.handler.join_group(group_id, user_id, content) - - return 200, new_content - - -class FederationGroupsRemoveUserServlet(BaseGroupsServerServlet): - """Leave or kick a user from the group""" - - PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove" - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - user_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.remove_user_from_group( - group_id, user_id, requester_user_id, content - ) - - return 200, new_content - - -class FederationGroupsSummaryRoomsServlet(BaseGroupsServerServlet): - """Add/remove a room from the group summary, with optional category. - - Matches both: - - /groups/:group/summary/rooms/:room_id - - /groups/:group/summary/categories/:category/rooms/:room_id - """ - - PATH = ( - "/groups/(?P<group_id>[^/]*)/summary" - "(/categories/(?P<category_id>[^/]+))?" - "/rooms/(?P<room_id>[^/]*)" - ) - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - category_id: str, - room_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - if category_id == "": - raise SynapseError( - 400, "category_id cannot be empty string", Codes.INVALID_PARAM - ) - - if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH: - raise SynapseError( - 400, - "category_id may not be longer than %s characters" - % (MAX_GROUP_CATEGORYID_LENGTH,), - Codes.INVALID_PARAM, - ) - - resp = await self.handler.update_group_summary_room( - group_id, - requester_user_id, - room_id=room_id, - category_id=category_id, - content=content, - ) - - return 200, resp - - async def on_DELETE( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - category_id: str, - room_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - if category_id == "": - raise SynapseError(400, "category_id cannot be empty string") - - resp = await self.handler.delete_group_summary_room( - group_id, requester_user_id, room_id=room_id, category_id=category_id - ) - - return 200, resp - - -class FederationGroupsCategoriesServlet(BaseGroupsServerServlet): - """Get all categories for a group""" - - PATH = "/groups/(?P<group_id>[^/]*)/categories/?" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - resp = await self.handler.get_group_categories(group_id, requester_user_id) - - return 200, resp - - -class FederationGroupsCategoryServlet(BaseGroupsServerServlet): - """Add/remove/get a category in a group""" - - PATH = "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - category_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - resp = await self.handler.get_group_category( - group_id, requester_user_id, category_id - ) - - return 200, resp - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - category_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - if category_id == "": - raise SynapseError(400, "category_id cannot be empty string") - - if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH: - raise SynapseError( - 400, - "category_id may not be longer than %s characters" - % (MAX_GROUP_CATEGORYID_LENGTH,), - Codes.INVALID_PARAM, - ) - - resp = await self.handler.upsert_group_category( - group_id, requester_user_id, category_id, content - ) - - return 200, resp - - async def on_DELETE( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - category_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - if category_id == "": - raise SynapseError(400, "category_id cannot be empty string") - - resp = await self.handler.delete_group_category( - group_id, requester_user_id, category_id - ) - - return 200, resp - - -class FederationGroupsRolesServlet(BaseGroupsServerServlet): - """Get roles in a group""" - - PATH = "/groups/(?P<group_id>[^/]*)/roles/?" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - resp = await self.handler.get_group_roles(group_id, requester_user_id) - - return 200, resp - - -class FederationGroupsRoleServlet(BaseGroupsServerServlet): - """Add/remove/get a role in a group""" - - PATH = "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)" - - async def on_GET( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - role_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - resp = await self.handler.get_group_role(group_id, requester_user_id, role_id) - - return 200, resp - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - role_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - if role_id == "": - raise SynapseError( - 400, "role_id cannot be empty string", Codes.INVALID_PARAM - ) - - if len(role_id) > MAX_GROUP_ROLEID_LENGTH: - raise SynapseError( - 400, - "role_id may not be longer than %s characters" - % (MAX_GROUP_ROLEID_LENGTH,), - Codes.INVALID_PARAM, - ) - - resp = await self.handler.update_group_role( - group_id, requester_user_id, role_id, content - ) - - return 200, resp - - async def on_DELETE( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - role_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - if role_id == "": - raise SynapseError(400, "role_id cannot be empty string") - - resp = await self.handler.delete_group_role( - group_id, requester_user_id, role_id - ) - - return 200, resp - - -class FederationGroupsSummaryUsersServlet(BaseGroupsServerServlet): - """Add/remove a user from the group summary, with optional role. - - Matches both: - - /groups/:group/summary/users/:user_id - - /groups/:group/summary/roles/:role/users/:user_id - """ - - PATH = ( - "/groups/(?P<group_id>[^/]*)/summary" - "(/roles/(?P<role_id>[^/]+))?" - "/users/(?P<user_id>[^/]*)" - ) - - async def on_POST( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - role_id: str, - user_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - if role_id == "": - raise SynapseError(400, "role_id cannot be empty string") - - if len(role_id) > MAX_GROUP_ROLEID_LENGTH: - raise SynapseError( - 400, - "role_id may not be longer than %s characters" - % (MAX_GROUP_ROLEID_LENGTH,), - Codes.INVALID_PARAM, - ) - - resp = await self.handler.update_group_summary_user( - group_id, - requester_user_id, - user_id=user_id, - role_id=role_id, - content=content, - ) - - return 200, resp - - async def on_DELETE( - self, - origin: str, - content: Literal[None], - query: Dict[bytes, List[bytes]], - group_id: str, - role_id: str, - user_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - if role_id == "": - raise SynapseError(400, "role_id cannot be empty string") - - resp = await self.handler.delete_group_summary_user( - group_id, requester_user_id, user_id=user_id, role_id=role_id - ) - - return 200, resp - - -class FederationGroupsSettingJoinPolicyServlet(BaseGroupsServerServlet): - """Sets whether a group is joinable without an invite or knock""" - - PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy" - - async def on_PUT( - self, - origin: str, - content: JsonDict, - query: Dict[bytes, List[bytes]], - group_id: str, - ) -> Tuple[int, JsonDict]: - requester_user_id = parse_string_from_args( - query, "requester_user_id", required=True - ) - if get_domain_from_id(requester_user_id) != origin: - raise SynapseError(403, "requester_user_id doesn't match origin") - - new_content = await self.handler.set_group_join_policy( - group_id, requester_user_id, content - ) - - return 200, new_content - - -GROUP_SERVER_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( - FederationGroupsProfileServlet, - FederationGroupsSummaryServlet, - FederationGroupsRoomsServlet, - FederationGroupsUsersServlet, - FederationGroupsInvitedUsersServlet, - FederationGroupsInviteServlet, - FederationGroupsAcceptInviteServlet, - FederationGroupsJoinServlet, - FederationGroupsRemoveUserServlet, - FederationGroupsSummaryRoomsServlet, - FederationGroupsCategoriesServlet, - FederationGroupsCategoryServlet, - FederationGroupsRolesServlet, - FederationGroupsRoleServlet, - FederationGroupsSummaryUsersServlet, - FederationGroupsAddRoomsServlet, - FederationGroupsAddRoomsConfigServlet, - FederationGroupsSettingJoinPolicyServlet, -) diff --git a/synapse/groups/__init__.py b/synapse/groups/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 --- a/synapse/groups/__init__.py +++ /dev/null diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py deleted file mode 100644 index ed26d6a6ce..0000000000 --- a/synapse/groups/attestations.py +++ /dev/null @@ -1,218 +0,0 @@ -# Copyright 2017 Vector Creations Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Attestations ensure that users and groups can't lie about their memberships. - -When a user joins a group the HS and GS swap attestations, which allow them -both to independently prove to third parties their membership.These -attestations have a validity period so need to be periodically renewed. - -If a user leaves (or gets kicked out of) a group, either side can still use -their attestation to "prove" their membership, until the attestation expires. -Therefore attestations shouldn't be relied on to prove membership in important -cases, but can for less important situations, e.g. showing a users membership -of groups on their profile, showing flairs, etc. - -An attestation is a signed blob of json that looks like: - - { - "user_id": "@foo:a.example.com", - "group_id": "+bar:b.example.com", - "valid_until_ms": 1507994728530, - "signatures":{"matrix.org":{"ed25519:auto":"..."}} - } -""" - -import logging -import random -from typing import TYPE_CHECKING, Optional, Tuple - -from signedjson.sign import sign_json - -from twisted.internet.defer import Deferred - -from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError -from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.types import JsonDict, get_domain_from_id - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -# Default validity duration for new attestations we create -DEFAULT_ATTESTATION_LENGTH_MS = 3 * 24 * 60 * 60 * 1000 - -# We add some jitter to the validity duration of attestations so that if we -# add lots of users at once we don't need to renew them all at once. -# The jitter is a multiplier picked randomly between the first and second number -DEFAULT_ATTESTATION_JITTER = (0.9, 1.3) - -# Start trying to update our attestations when they come this close to expiring -UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000 - - -class GroupAttestationSigning: - """Creates and verifies group attestations.""" - - def __init__(self, hs: "HomeServer"): - self.keyring = hs.get_keyring() - self.clock = hs.get_clock() - self.server_name = hs.hostname - self.signing_key = hs.signing_key - - async def verify_attestation( - self, - attestation: JsonDict, - group_id: str, - user_id: str, - server_name: Optional[str] = None, - ) -> None: - """Verifies that the given attestation matches the given parameters. - - An optional server_name can be supplied to explicitly set which server's - signature is expected. Otherwise assumes that either the group_id or user_id - is local and uses the other's server as the one to check. - """ - - if not server_name: - if get_domain_from_id(group_id) == self.server_name: - server_name = get_domain_from_id(user_id) - elif get_domain_from_id(user_id) == self.server_name: - server_name = get_domain_from_id(group_id) - else: - raise Exception("Expected either group_id or user_id to be local") - - if user_id != attestation["user_id"]: - raise SynapseError(400, "Attestation has incorrect user_id") - - if group_id != attestation["group_id"]: - raise SynapseError(400, "Attestation has incorrect group_id") - valid_until_ms = attestation["valid_until_ms"] - - # TODO: We also want to check that *new* attestations that people give - # us to store are valid for at least a little while. - now = self.clock.time_msec() - if valid_until_ms < now: - raise SynapseError(400, "Attestation expired") - - assert server_name is not None - await self.keyring.verify_json_for_server( - server_name, - attestation, - now, - ) - - def create_attestation(self, group_id: str, user_id: str) -> JsonDict: - """Create an attestation for the group_id and user_id with default - validity length. - """ - validity_period = DEFAULT_ATTESTATION_LENGTH_MS * random.uniform( - *DEFAULT_ATTESTATION_JITTER - ) - valid_until_ms = int(self.clock.time_msec() + validity_period) - - return sign_json( - { - "group_id": group_id, - "user_id": user_id, - "valid_until_ms": valid_until_ms, - }, - self.server_name, - self.signing_key, - ) - - -class GroupAttestionRenewer: - """Responsible for sending and receiving attestation updates.""" - - def __init__(self, hs: "HomeServer"): - self.clock = hs.get_clock() - self.store = hs.get_datastores().main - self.assestations = hs.get_groups_attestation_signing() - self.transport_client = hs.get_federation_transport_client() - self.is_mine_id = hs.is_mine_id - self.attestations = hs.get_groups_attestation_signing() - - if not hs.config.worker.worker_app: - self._renew_attestations_loop = self.clock.looping_call( - self._start_renew_attestations, 30 * 60 * 1000 - ) - - async def on_renew_attestation( - self, group_id: str, user_id: str, content: JsonDict - ) -> JsonDict: - """When a remote updates an attestation""" - attestation = content["attestation"] - - if not self.is_mine_id(group_id) and not self.is_mine_id(user_id): - raise SynapseError(400, "Neither user not group are on this server") - - await self.attestations.verify_attestation( - attestation, user_id=user_id, group_id=group_id - ) - - await self.store.update_remote_attestion(group_id, user_id, attestation) - - return {} - - def _start_renew_attestations(self) -> "Deferred[None]": - return run_as_background_process("renew_attestations", self._renew_attestations) - - async def _renew_attestations(self) -> None: - """Called periodically to check if we need to update any of our attestations""" - - now = self.clock.time_msec() - - rows = await self.store.get_attestations_need_renewals( - now + UPDATE_ATTESTATION_TIME_MS - ) - - async def _renew_attestation(group_user: Tuple[str, str]) -> None: - group_id, user_id = group_user - try: - if not self.is_mine_id(group_id): - destination = get_domain_from_id(group_id) - elif not self.is_mine_id(user_id): - destination = get_domain_from_id(user_id) - else: - logger.warning( - "Incorrectly trying to do attestations for user: %r in %r", - user_id, - group_id, - ) - await self.store.remove_attestation_renewal(group_id, user_id) - return - - attestation = self.attestations.create_attestation(group_id, user_id) - - await self.transport_client.renew_group_attestation( - destination, group_id, user_id, content={"attestation": attestation} - ) - - await self.store.update_attestation_renewal( - group_id, user_id, attestation - ) - except (RequestSendFailed, HttpResponseException) as e: - logger.warning( - "Failed to renew attestation of %r in %r: %s", user_id, group_id, e - ) - except Exception: - logger.exception( - "Error renewing attestation of %r in %r", user_id, group_id - ) - - for row in rows: - await _renew_attestation((row["group_id"], row["user_id"])) diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py deleted file mode 100644 index dfd24af695..0000000000 --- a/synapse/groups/groups_server.py +++ /dev/null @@ -1,1019 +0,0 @@ -# Copyright 2017 Vector Creations Ltd -# Copyright 2018 New Vector Ltd -# Copyright 2019 Michael Telatynski <7t3chguy@gmail.com> -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from typing import TYPE_CHECKING, Optional - -from synapse.api.errors import Codes, SynapseError -from synapse.handlers.groups_local import GroupsLocalHandler -from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN -from synapse.types import GroupID, JsonDict, RoomID, UserID, get_domain_from_id -from synapse.util.async_helpers import concurrently_execute - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -# TODO: Allow users to "knock" or simply join depending on rules -# TODO: Federation admin APIs -# TODO: is_privileged flag to users and is_public to users and rooms -# TODO: Audit log for admins (profile updates, membership changes, users who tried -# to join but were rejected, etc) -# TODO: Flairs - - -# Note that the maximum lengths are somewhat arbitrary. -MAX_SHORT_DESC_LEN = 1000 -MAX_LONG_DESC_LEN = 10000 - - -class GroupsServerWorkerHandler: - def __init__(self, hs: "HomeServer"): - self.hs = hs - self.store = hs.get_datastores().main - self.room_list_handler = hs.get_room_list_handler() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.keyring = hs.get_keyring() - self.is_mine_id = hs.is_mine_id - self.signing_key = hs.signing_key - self.server_name = hs.hostname - self.attestations = hs.get_groups_attestation_signing() - self.transport_client = hs.get_federation_transport_client() - self.profile_handler = hs.get_profile_handler() - - async def check_group_is_ours( - self, - group_id: str, - requester_user_id: str, - and_exists: bool = False, - and_is_admin: Optional[str] = None, - ) -> Optional[dict]: - """Check that the group is ours, and optionally if it exists. - - If group does exist then return group. - - Args: - group_id: The group ID to check. - requester_user_id: The user ID of the requester. - and_exists: whether to also check if group exists - and_is_admin: whether to also check if given str is a user_id - that is an admin - """ - if not self.is_mine_id(group_id): - raise SynapseError(400, "Group not on this server") - - group = await self.store.get_group(group_id) - if and_exists and not group: - raise SynapseError(404, "Unknown group") - - is_user_in_group = await self.store.is_user_in_group( - requester_user_id, group_id - ) - if group and not is_user_in_group and not group["is_public"]: - raise SynapseError(404, "Unknown group") - - if and_is_admin: - is_admin = await self.store.is_user_admin_in_group(group_id, and_is_admin) - if not is_admin: - raise SynapseError(403, "User is not admin in group") - - return group - - async def get_group_summary( - self, group_id: str, requester_user_id: str - ) -> JsonDict: - """Get the summary for a group as seen by requester_user_id. - - The group summary consists of the profile of the room, and a curated - list of users and rooms. These list *may* be organised by role/category. - The roles/categories are ordered, and so are the users/rooms within them. - - A user/room may appear in multiple roles/categories. - """ - await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - - is_user_in_group = await self.store.is_user_in_group( - requester_user_id, group_id - ) - - profile = await self.get_group_profile(group_id, requester_user_id) - - users, roles = await self.store.get_users_for_summary_by_role( - group_id, include_private=is_user_in_group - ) - - # TODO: Add profiles to users - - rooms, categories = await self.store.get_rooms_for_summary_by_category( - group_id, include_private=is_user_in_group - ) - - for room_entry in rooms: - room_id = room_entry["room_id"] - joined_users = await self.store.get_users_in_room(room_id) - entry = await self.room_list_handler.generate_room_entry( - room_id, len(joined_users), with_alias=False, allow_private=True - ) - if entry is None: - continue - entry = dict(entry) # so we don't change what's cached - entry.pop("room_id", None) - - room_entry["profile"] = entry - - rooms.sort(key=lambda e: e.get("order", 0)) - - for user in users: - user_id = user["user_id"] - - if not self.is_mine_id(requester_user_id): - attestation = await self.store.get_remote_attestation(group_id, user_id) - if not attestation: - continue - - user["attestation"] = attestation - else: - user["attestation"] = self.attestations.create_attestation( - group_id, user_id - ) - - user_profile = await self.profile_handler.get_profile_from_cache(user_id) - user.update(user_profile) - - users.sort(key=lambda e: e.get("order", 0)) - - membership_info = await self.store.get_users_membership_info_in_group( - group_id, requester_user_id - ) - - return { - "profile": profile, - "users_section": { - "users": users, - "roles": roles, - "total_user_count_estimate": 0, # TODO - }, - "rooms_section": { - "rooms": rooms, - "categories": categories, - "total_room_count_estimate": 0, # TODO - }, - "user": membership_info, - } - - async def get_group_categories( - self, group_id: str, requester_user_id: str - ) -> JsonDict: - """Get all categories in a group (as seen by user)""" - await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - - categories = await self.store.get_group_categories(group_id=group_id) - return {"categories": categories} - - async def get_group_category( - self, group_id: str, requester_user_id: str, category_id: str - ) -> JsonDict: - """Get a specific category in a group (as seen by user)""" - await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - - return await self.store.get_group_category( - group_id=group_id, category_id=category_id - ) - - async def get_group_roles(self, group_id: str, requester_user_id: str) -> JsonDict: - """Get all roles in a group (as seen by user)""" - await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - - roles = await self.store.get_group_roles(group_id=group_id) - return {"roles": roles} - - async def get_group_role( - self, group_id: str, requester_user_id: str, role_id: str - ) -> JsonDict: - """Get a specific role in a group (as seen by user)""" - await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - - return await self.store.get_group_role(group_id=group_id, role_id=role_id) - - async def get_group_profile( - self, group_id: str, requester_user_id: str - ) -> JsonDict: - """Get the group profile as seen by requester_user_id""" - - await self.check_group_is_ours(group_id, requester_user_id) - - group = await self.store.get_group(group_id) - - if group: - cols = [ - "name", - "short_description", - "long_description", - "avatar_url", - "is_public", - ] - group_description = {key: group[key] for key in cols} - group_description["is_openly_joinable"] = group["join_policy"] == "open" - - return group_description - else: - raise SynapseError(404, "Unknown group") - - async def get_users_in_group( - self, group_id: str, requester_user_id: str - ) -> JsonDict: - """Get the users in group as seen by requester_user_id. - - The ordering is arbitrary at the moment - """ - - await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - - is_user_in_group = await self.store.is_user_in_group( - requester_user_id, group_id - ) - - user_results = await self.store.get_users_in_group( - group_id, include_private=is_user_in_group - ) - - chunk = [] - for user_result in user_results: - g_user_id = user_result["user_id"] - is_public = user_result["is_public"] - is_privileged = user_result["is_admin"] - - entry = {"user_id": g_user_id} - - profile = await self.profile_handler.get_profile_from_cache(g_user_id) - entry.update(profile) - - entry["is_public"] = bool(is_public) - entry["is_privileged"] = bool(is_privileged) - - if not self.is_mine_id(g_user_id): - attestation = await self.store.get_remote_attestation( - group_id, g_user_id - ) - if not attestation: - continue - - entry["attestation"] = attestation - else: - entry["attestation"] = self.attestations.create_attestation( - group_id, g_user_id - ) - - chunk.append(entry) - - # TODO: If admin add lists of users whose attestations have timed out - - return {"chunk": chunk, "total_user_count_estimate": len(user_results)} - - async def get_invited_users_in_group( - self, group_id: str, requester_user_id: str - ) -> JsonDict: - """Get the users that have been invited to a group as seen by requester_user_id. - - The ordering is arbitrary at the moment - """ - - await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - - is_user_in_group = await self.store.is_user_in_group( - requester_user_id, group_id - ) - - if not is_user_in_group: - raise SynapseError(403, "User not in group") - - invited_users = await self.store.get_invited_users_in_group(group_id) - - user_profiles = [] - - for user_id in invited_users: - user_profile = {"user_id": user_id} - try: - profile = await self.profile_handler.get_profile_from_cache(user_id) - user_profile.update(profile) - except Exception as e: - logger.warning("Error getting profile for %s: %s", user_id, e) - user_profiles.append(user_profile) - - return {"chunk": user_profiles, "total_user_count_estimate": len(invited_users)} - - async def get_rooms_in_group( - self, group_id: str, requester_user_id: str - ) -> JsonDict: - """Get the rooms in group as seen by requester_user_id - - This returns rooms in order of decreasing number of joined users - """ - - await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - - is_user_in_group = await self.store.is_user_in_group( - requester_user_id, group_id - ) - - # Note! room_results["is_public"] is about whether the room is considered - # public from the group's point of view. (i.e. whether non-group members - # should be able to see the room is in the group). - # This is not the same as whether the room itself is public (in the sense - # of being visible in the room directory). - # As such, room_results["is_public"] itself is not sufficient to determine - # whether any given user is permitted to see the room's metadata. - room_results = await self.store.get_rooms_in_group( - group_id, include_private=is_user_in_group - ) - - chunk = [] - for room_result in room_results: - room_id = room_result["room_id"] - - joined_users = await self.store.get_users_in_room(room_id) - - # check the user is actually allowed to see the room before showing it to them - allow_private = requester_user_id in joined_users - - entry = await self.room_list_handler.generate_room_entry( - room_id, - len(joined_users), - with_alias=False, - allow_private=allow_private, - ) - - if not entry: - continue - - entry["is_public"] = bool(room_result["is_public"]) - - chunk.append(entry) - - chunk.sort(key=lambda e: -e["num_joined_members"]) - - return {"chunk": chunk, "total_room_count_estimate": len(chunk)} - - -class GroupsServerHandler(GroupsServerWorkerHandler): - def __init__(self, hs: "HomeServer"): - super().__init__(hs) - - # Ensure attestations get renewed - hs.get_groups_attestation_renewer() - - async def update_group_summary_room( - self, - group_id: str, - requester_user_id: str, - room_id: str, - category_id: str, - content: JsonDict, - ) -> JsonDict: - """Add/update a room to the group summary""" - await self.check_group_is_ours( - group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id - ) - - RoomID.from_string(room_id) # Ensure valid room id - - order = content.get("order", None) - - is_public = _parse_visibility_from_contents(content) - - await self.store.add_room_to_summary( - group_id=group_id, - room_id=room_id, - category_id=category_id, - order=order, - is_public=is_public, - ) - - return {} - - async def delete_group_summary_room( - self, group_id: str, requester_user_id: str, room_id: str, category_id: str - ) -> JsonDict: - """Remove a room from the summary""" - await self.check_group_is_ours( - group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id - ) - - await self.store.remove_room_from_summary( - group_id=group_id, room_id=room_id, category_id=category_id - ) - - return {} - - async def set_group_join_policy( - self, group_id: str, requester_user_id: str, content: JsonDict - ) -> JsonDict: - """Sets the group join policy. - - Currently supported policies are: - - "invite": an invite must be received and accepted in order to join. - - "open": anyone can join. - """ - await self.check_group_is_ours( - group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id - ) - - join_policy = _parse_join_policy_from_contents(content) - if join_policy is None: - raise SynapseError(400, "No value specified for 'm.join_policy'") - - await self.store.set_group_join_policy(group_id, join_policy=join_policy) - - return {} - - async def update_group_category( - self, group_id: str, requester_user_id: str, category_id: str, content: JsonDict - ) -> JsonDict: - """Add/Update a group category""" - await self.check_group_is_ours( - group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id - ) - - is_public = _parse_visibility_from_contents(content) - profile = content.get("profile") - - await self.store.upsert_group_category( - group_id=group_id, - category_id=category_id, - is_public=is_public, - profile=profile, - ) - - return {} - - async def delete_group_category( - self, group_id: str, requester_user_id: str, category_id: str - ) -> JsonDict: - """Delete a group category""" - await self.check_group_is_ours( - group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id - ) - - await self.store.remove_group_category( - group_id=group_id, category_id=category_id - ) - - return {} - - async def update_group_role( - self, group_id: str, requester_user_id: str, role_id: str, content: JsonDict - ) -> JsonDict: - """Add/update a role in a group""" - await self.check_group_is_ours( - group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id - ) - - is_public = _parse_visibility_from_contents(content) - - profile = content.get("profile") - - await self.store.upsert_group_role( - group_id=group_id, role_id=role_id, is_public=is_public, profile=profile - ) - - return {} - - async def delete_group_role( - self, group_id: str, requester_user_id: str, role_id: str - ) -> JsonDict: - """Remove role from group""" - await self.check_group_is_ours( - group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id - ) - - await self.store.remove_group_role(group_id=group_id, role_id=role_id) - - return {} - - async def update_group_summary_user( - self, - group_id: str, - requester_user_id: str, - user_id: str, - role_id: str, - content: JsonDict, - ) -> JsonDict: - """Add/update a users entry in the group summary""" - await self.check_group_is_ours( - group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id - ) - - order = content.get("order", None) - - is_public = _parse_visibility_from_contents(content) - - await self.store.add_user_to_summary( - group_id=group_id, - user_id=user_id, - role_id=role_id, - order=order, - is_public=is_public, - ) - - return {} - - async def delete_group_summary_user( - self, group_id: str, requester_user_id: str, user_id: str, role_id: str - ) -> JsonDict: - """Remove a user from the group summary""" - await self.check_group_is_ours( - group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id - ) - - await self.store.remove_user_from_summary( - group_id=group_id, user_id=user_id, role_id=role_id - ) - - return {} - - async def update_group_profile( - self, group_id: str, requester_user_id: str, content: JsonDict - ) -> None: - """Update the group profile""" - await self.check_group_is_ours( - group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id - ) - - profile = {} - for keyname, max_length in ( - ("name", MAX_DISPLAYNAME_LEN), - ("avatar_url", MAX_AVATAR_URL_LEN), - ("short_description", MAX_SHORT_DESC_LEN), - ("long_description", MAX_LONG_DESC_LEN), - ): - if keyname in content: - value = content[keyname] - if not isinstance(value, str): - raise SynapseError( - 400, - "%r value is not a string" % (keyname,), - errcode=Codes.INVALID_PARAM, - ) - if len(value) > max_length: - raise SynapseError( - 400, - "Invalid %s parameter" % (keyname,), - errcode=Codes.INVALID_PARAM, - ) - profile[keyname] = value - - await self.store.update_group_profile(group_id, profile) - - async def add_room_to_group( - self, group_id: str, requester_user_id: str, room_id: str, content: JsonDict - ) -> JsonDict: - """Add room to group""" - RoomID.from_string(room_id) # Ensure valid room id - - await self.check_group_is_ours( - group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id - ) - - is_public = _parse_visibility_from_contents(content) - - await self.store.add_room_to_group(group_id, room_id, is_public=is_public) - - return {} - - async def update_room_in_group( - self, - group_id: str, - requester_user_id: str, - room_id: str, - config_key: str, - content: JsonDict, - ) -> JsonDict: - """Update room in group""" - RoomID.from_string(room_id) # Ensure valid room id - - await self.check_group_is_ours( - group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id - ) - - if config_key == "m.visibility": - is_public = _parse_visibility_dict(content) - - await self.store.update_room_in_group_visibility( - group_id, room_id, is_public=is_public - ) - else: - raise SynapseError(400, "Unknown config option") - - return {} - - async def remove_room_from_group( - self, group_id: str, requester_user_id: str, room_id: str - ) -> JsonDict: - """Remove room from group""" - await self.check_group_is_ours( - group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id - ) - - await self.store.remove_room_from_group(group_id, room_id) - - return {} - - async def invite_to_group( - self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict - ) -> JsonDict: - """Invite user to group""" - - group = await self.check_group_is_ours( - group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id - ) - if not group: - raise SynapseError(400, "Group does not exist", errcode=Codes.BAD_STATE) - - # TODO: Check if user knocked - - invited_users = await self.store.get_invited_users_in_group(group_id) - if user_id in invited_users: - raise SynapseError( - 400, "User already invited to group", errcode=Codes.BAD_STATE - ) - - user_results = await self.store.get_users_in_group( - group_id, include_private=True - ) - if user_id in (user_result["user_id"] for user_result in user_results): - raise SynapseError(400, "User already in group") - - content = { - "profile": {"name": group["name"], "avatar_url": group["avatar_url"]}, - "inviter": requester_user_id, - } - - if self.hs.is_mine_id(user_id): - groups_local = self.hs.get_groups_local_handler() - assert isinstance( - groups_local, GroupsLocalHandler - ), "Workers cannot invites users to groups." - res = await groups_local.on_invite(group_id, user_id, content) - local_attestation = None - else: - local_attestation = self.attestations.create_attestation(group_id, user_id) - content.update({"attestation": local_attestation}) - - res = await self.transport_client.invite_to_group_notification( - get_domain_from_id(user_id), group_id, user_id, content - ) - - user_profile = res.get("user_profile", {}) - await self.store.add_remote_profile_cache( - user_id, - displayname=user_profile.get("displayname"), - avatar_url=user_profile.get("avatar_url"), - ) - - if res["state"] == "join": - if not self.hs.is_mine_id(user_id): - remote_attestation = res["attestation"] - - await self.attestations.verify_attestation( - remote_attestation, user_id=user_id, group_id=group_id - ) - else: - remote_attestation = None - - await self.store.add_user_to_group( - group_id, - user_id, - is_admin=False, - is_public=False, # TODO - local_attestation=local_attestation, - remote_attestation=remote_attestation, - ) - return {"state": "join"} - elif res["state"] == "invite": - await self.store.add_group_invite(group_id, user_id) - return {"state": "invite"} - elif res["state"] == "reject": - return {"state": "reject"} - else: - raise SynapseError(502, "Unknown state returned by HS") - - async def _add_user( - self, group_id: str, user_id: str, content: JsonDict - ) -> Optional[JsonDict]: - """Add a user to a group based on a content dict. - - See accept_invite, join_group. - """ - if not self.hs.is_mine_id(user_id): - local_attestation: Optional[ - JsonDict - ] = self.attestations.create_attestation(group_id, user_id) - - remote_attestation = content["attestation"] - - await self.attestations.verify_attestation( - remote_attestation, user_id=user_id, group_id=group_id - ) - else: - local_attestation = None - remote_attestation = None - - is_public = _parse_visibility_from_contents(content) - - await self.store.add_user_to_group( - group_id, - user_id, - is_admin=False, - is_public=is_public, - local_attestation=local_attestation, - remote_attestation=remote_attestation, - ) - - return local_attestation - - async def accept_invite( - self, group_id: str, requester_user_id: str, content: JsonDict - ) -> JsonDict: - """User tries to accept an invite to the group. - - This is different from them asking to join, and so should error if no - invite exists (and they're not a member of the group) - """ - - await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - - is_invited = await self.store.is_user_invited_to_local_group( - group_id, requester_user_id - ) - if not is_invited: - raise SynapseError(403, "User not invited to group") - - local_attestation = await self._add_user(group_id, requester_user_id, content) - - return {"state": "join", "attestation": local_attestation} - - async def join_group( - self, group_id: str, requester_user_id: str, content: JsonDict - ) -> JsonDict: - """User tries to join the group. - - This will error if the group requires an invite/knock to join - """ - - group_info = await self.check_group_is_ours( - group_id, requester_user_id, and_exists=True - ) - if not group_info: - raise SynapseError(404, "Group does not exist", errcode=Codes.NOT_FOUND) - if group_info["join_policy"] != "open": - raise SynapseError(403, "Group is not publicly joinable") - - local_attestation = await self._add_user(group_id, requester_user_id, content) - - return {"state": "join", "attestation": local_attestation} - - async def remove_user_from_group( - self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict - ) -> JsonDict: - """Remove a user from the group; either a user is leaving or an admin - kicked them. - """ - - await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - - is_kick = False - if requester_user_id != user_id: - is_admin = await self.store.is_user_admin_in_group( - group_id, requester_user_id - ) - if not is_admin: - raise SynapseError(403, "User is not admin in group") - - is_kick = True - - await self.store.remove_user_from_group(group_id, user_id) - - if is_kick: - if self.hs.is_mine_id(user_id): - groups_local = self.hs.get_groups_local_handler() - assert isinstance( - groups_local, GroupsLocalHandler - ), "Workers cannot remove users from groups." - await groups_local.user_removed_from_group(group_id, user_id, {}) - else: - await self.transport_client.remove_user_from_group_notification( - get_domain_from_id(user_id), group_id, user_id, {} - ) - - if not self.hs.is_mine_id(user_id): - await self.store.maybe_delete_remote_profile_cache(user_id) - - # Delete group if the last user has left - users = await self.store.get_users_in_group(group_id, include_private=True) - if not users: - await self.store.delete_group(group_id) - - return {} - - async def create_group( - self, group_id: str, requester_user_id: str, content: JsonDict - ) -> JsonDict: - logger.info("Attempting to create group with ID: %r", group_id) - - # parsing the id into a GroupID validates it. - group_id_obj = GroupID.from_string(group_id) - - group = await self.check_group_is_ours(group_id, requester_user_id) - if group: - raise SynapseError(400, "Group already exists") - - is_admin = await self.auth.is_server_admin( - UserID.from_string(requester_user_id) - ) - if not is_admin: - if not self.hs.config.groups.enable_group_creation: - raise SynapseError( - 403, "Only a server admin can create groups on this server" - ) - localpart = group_id_obj.localpart - if not localpart.startswith(self.hs.config.groups.group_creation_prefix): - raise SynapseError( - 400, - "Can only create groups with prefix %r on this server" - % (self.hs.config.groups.group_creation_prefix,), - ) - - profile = content.get("profile", {}) - name = profile.get("name") - avatar_url = profile.get("avatar_url") - short_description = profile.get("short_description") - long_description = profile.get("long_description") - user_profile = content.get("user_profile", {}) - - await self.store.create_group( - group_id, - requester_user_id, - name=name, - avatar_url=avatar_url, - short_description=short_description, - long_description=long_description, - ) - - if not self.hs.is_mine_id(requester_user_id): - remote_attestation = content["attestation"] - - await self.attestations.verify_attestation( - remote_attestation, user_id=requester_user_id, group_id=group_id - ) - - local_attestation: Optional[ - JsonDict - ] = self.attestations.create_attestation(group_id, requester_user_id) - else: - local_attestation = None - remote_attestation = None - - await self.store.add_user_to_group( - group_id, - requester_user_id, - is_admin=True, - is_public=True, # TODO - local_attestation=local_attestation, - remote_attestation=remote_attestation, - ) - - if not self.hs.is_mine_id(requester_user_id): - await self.store.add_remote_profile_cache( - requester_user_id, - displayname=user_profile.get("displayname"), - avatar_url=user_profile.get("avatar_url"), - ) - - return {"group_id": group_id} - - async def delete_group(self, group_id: str, requester_user_id: str) -> None: - """Deletes a group, kicking out all current members. - - Only group admins or server admins can call this request - - Args: - group_id: The group ID to delete. - requester_user_id: The user requesting to delete the group. - """ - - await self.check_group_is_ours(group_id, requester_user_id, and_exists=True) - - # Only server admins or group admins can delete groups. - - is_admin = await self.store.is_user_admin_in_group(group_id, requester_user_id) - - if not is_admin: - is_admin = await self.auth.is_server_admin( - UserID.from_string(requester_user_id) - ) - - if not is_admin: - raise SynapseError(403, "User is not an admin") - - # Before deleting the group lets kick everyone out of it - users = await self.store.get_users_in_group(group_id, include_private=True) - - async def _kick_user_from_group(user_id: str) -> None: - if self.hs.is_mine_id(user_id): - groups_local = self.hs.get_groups_local_handler() - assert isinstance( - groups_local, GroupsLocalHandler - ), "Workers cannot kick users from groups." - await groups_local.user_removed_from_group(group_id, user_id, {}) - else: - await self.transport_client.remove_user_from_group_notification( - get_domain_from_id(user_id), group_id, user_id, {} - ) - await self.store.maybe_delete_remote_profile_cache(user_id) - - # We kick users out in the order of: - # 1. Non-admins - # 2. Other admins - # 3. The requester - # - # This is so that if the deletion fails for some reason other admins or - # the requester still has auth to retry. - non_admins = [] - admins = [] - for u in users: - if u["user_id"] == requester_user_id: - continue - if u["is_admin"]: - admins.append(u["user_id"]) - else: - non_admins.append(u["user_id"]) - - await concurrently_execute(_kick_user_from_group, non_admins, 10) - await concurrently_execute(_kick_user_from_group, admins, 10) - await _kick_user_from_group(requester_user_id) - - await self.store.delete_group(group_id) - - -def _parse_join_policy_from_contents(content: JsonDict) -> Optional[str]: - """Given a content for a request, return the specified join policy or None""" - - join_policy_dict = content.get("m.join_policy") - if join_policy_dict: - return _parse_join_policy_dict(join_policy_dict) - else: - return None - - -def _parse_join_policy_dict(join_policy_dict: JsonDict) -> str: - """Given a dict for the "m.join_policy" config return the join policy specified""" - join_policy_type = join_policy_dict.get("type") - if not join_policy_type: - return "invite" - - if join_policy_type not in ("invite", "open"): - raise SynapseError(400, "Synapse only supports 'invite'/'open' join rule") - return join_policy_type - - -def _parse_visibility_from_contents(content: JsonDict) -> bool: - """Given a content for a request parse out whether the entity should be - public or not - """ - - visibility = content.get("m.visibility") - if visibility: - return _parse_visibility_dict(visibility) - else: - is_public = True - - return is_public - - -def _parse_visibility_dict(visibility: JsonDict) -> bool: - """Given a dict for the "m.visibility" config return if the entity should - be public or not - """ - vis_type = visibility.get("type") - if not vis_type: - return True - - if vis_type not in ("public", "private"): - raise SynapseError(400, "Synapse only supports 'public'/'private' visibility") - return vis_type == "public" diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 96376963f2..d4fe7df533 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -30,8 +30,8 @@ logger = logging.getLogger(__name__) class AdminHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main - self.storage = hs.get_storage() - self.state_store = self.storage.state + self._storage_controllers = hs.get_storage_controllers() + self._state_storage_controller = self._storage_controllers.state async def get_whois(self, user: UserID) -> JsonDict: connections = [] @@ -197,7 +197,9 @@ class AdminHandler: from_key = events[-1].internal_metadata.after - events = await filter_events_for_client(self.storage, user_id, events) + events = await filter_events_for_client( + self._storage_controllers, user_id, events + ) writer.write_events(room_id, events) @@ -233,7 +235,9 @@ class AdminHandler: for event_id in extremities: if not event_to_unseen_prevs[event_id]: continue - state = await self.state_store.get_state_for_event(event_id) + state = await self._state_storage_controller.get_state_for_event( + event_id + ) writer.write_state(room_id, event_id, state) return writer.finished() diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 1da7bcc85b..814553e098 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -19,7 +19,7 @@ from prometheus_client import Counter from twisted.internet import defer import synapse -from synapse.api.constants import EventTypes +from synapse.api.constants import EduTypes, EventTypes from synapse.appservice import ApplicationService from synapse.events import EventBase from synapse.handlers.presence import format_user_presence_state @@ -503,7 +503,7 @@ class ApplicationServicesHandler: time_now = self.clock.time_msec() events.extend( { - "type": "m.presence", + "type": EduTypes.PRESENCE, "sender": event.user_id, "content": format_user_presence_state( event, time_now, include_user_id=False diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 1d6d1f8a92..72faf2ee38 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -28,7 +28,7 @@ from typing import ( ) from synapse.api import errors -from synapse.api.constants import EventTypes +from synapse.api.constants import EduTypes, EventTypes from synapse.api.errors import ( Codes, FederationDeniedError, @@ -61,6 +61,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) MAX_DEVICE_DISPLAY_NAME_LEN = 100 +DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000 class DeviceWorkerHandler: @@ -70,7 +71,7 @@ class DeviceWorkerHandler: self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self.state = hs.get_state_handler() - self.state_store = hs.get_storage().state + self._state_storage = hs.get_storage_controllers().state self._auth_handler = hs.get_auth_handler() self.server_name = hs.hostname @@ -203,7 +204,9 @@ class DeviceWorkerHandler: continue # mapping from event_id -> state_dict - prev_state_ids = await self.state_store.get_state_ids_for_events(event_ids) + prev_state_ids = await self._state_storage.get_state_ids_for_events( + event_ids + ) # Check if we've joined the room? If so we just blindly add all the users to # the "possibly changed" users. @@ -277,7 +280,8 @@ class DeviceHandler(DeviceWorkerHandler): federation_registry = hs.get_federation_registry() federation_registry.register_edu_handler( - "m.device_list_update", self.device_list_updater.incoming_device_list_update + EduTypes.DEVICE_LIST_UPDATE, + self.device_list_updater.incoming_device_list_update, ) hs.get_distributor().observe("user_left_room", self.user_left_room) @@ -292,6 +296,19 @@ class DeviceHandler(DeviceWorkerHandler): # On start up check if there are any updates pending. hs.get_reactor().callWhenRunning(self._handle_new_device_update_async) + self._delete_stale_devices_after = hs.config.server.delete_stale_devices_after + + # Ideally we would run this on a worker and condition this on the + # "run_background_tasks_on" setting, but this would mean making the notification + # of device list changes over federation work on workers, which is nontrivial. + if self._delete_stale_devices_after is not None: + self.clock.looping_call( + run_as_background_process, + DELETE_STALE_DEVICES_INTERVAL_MS, + "delete_stale_devices", + self._delete_stale_devices, + ) + def _check_device_name_length(self, name: Optional[str]) -> None: """ Checks whether a device name is longer than the maximum allowed length. @@ -367,6 +384,19 @@ class DeviceHandler(DeviceWorkerHandler): raise errors.StoreError(500, "Couldn't generate a device ID.") + async def _delete_stale_devices(self) -> None: + """Background task that deletes devices which haven't been accessed for more than + a configured time period. + """ + # We should only be running this job if the config option is defined. + assert self._delete_stale_devices_after is not None + now_ms = self.clock.time_msec() + since_ms = now_ms - self._delete_stale_devices_after + devices = await self.store.get_local_devices_not_accessed_since(since_ms) + + for user_id, user_devices in devices.items(): + await self.delete_devices(user_id, user_devices) + @trace async def delete_device(self, user_id: str, device_id: str) -> None: """Delete the given device @@ -689,7 +719,8 @@ class DeviceHandler(DeviceWorkerHandler): ) # TODO: when called, this isn't in a logging context. # This leads to log spam, sentry event spam, and massive - # memory usage. See #12552. + # memory usage. + # See https://github.com/matrix-org/synapse/issues/12552. # log_kv( # {"message": "sent device update to host", "host": host} # ) @@ -763,6 +794,10 @@ class DeviceListUpdater: device_id = edu_content.pop("device_id") stream_id = str(edu_content.pop("stream_id")) # They may come as ints prev_ids = edu_content.pop("prev_id", []) + if not isinstance(prev_ids, list): + raise SynapseError( + 400, "Device list update had an invalid 'prev_ids' field" + ) prev_ids = [str(p) for p in prev_ids] # They may come as ints if get_domain_from_id(user_id) != origin: diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 53668cce3b..444c08bc2e 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -15,7 +15,7 @@ import logging from typing import TYPE_CHECKING, Any, Dict -from synapse.api.constants import ToDeviceEventTypes +from synapse.api.constants import EduTypes, ToDeviceEventTypes from synapse.api.errors import SynapseError from synapse.api.ratelimiting import Ratelimiter from synapse.logging.context import run_in_background @@ -59,11 +59,11 @@ class DeviceMessageHandler: # to the appropriate worker. if hs.get_instance_name() in hs.config.worker.writers.to_device: hs.get_federation_registry().register_edu_handler( - "m.direct_to_device", self.on_direct_to_device_edu + EduTypes.DIRECT_TO_DEVICE, self.on_direct_to_device_edu ) else: hs.get_federation_registry().register_instances_for_edu( - "m.direct_to_device", + EduTypes.DIRECT_TO_DEVICE, hs.config.worker.writers.to_device, ) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index e6c2cfb8c8..52bb5c9c55 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -25,6 +25,7 @@ from unpaddedbase64 import decode_base64 from twisted.internet import defer +from synapse.api.constants import EduTypes from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace @@ -66,13 +67,13 @@ class E2eKeysHandler: # Only register this edu handler on master as it requires writing # device updates to the db federation_registry.register_edu_handler( - "m.signing_key_update", + EduTypes.SIGNING_KEY_UPDATE, self._edu_updater.incoming_signing_key_update, ) # also handle the unstable version # FIXME: remove this when enough servers have upgraded federation_registry.register_edu_handler( - "org.matrix.signing_key_update", + EduTypes.UNSTABLE_SIGNING_KEY_UPDATE, self._edu_updater.incoming_signing_key_update, ) diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 82a5aac3dd..ac13340d3a 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -113,7 +113,7 @@ class EventStreamHandler: states = await presence_handler.get_states(users) to_add.extend( { - "type": EduTypes.Presence, + "type": EduTypes.PRESENCE, "content": format_user_presence_state(state, time_now), } for state in states @@ -139,7 +139,7 @@ class EventStreamHandler: class EventHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main - self.storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() async def get_event( self, @@ -177,7 +177,7 @@ class EventHandler: is_peeking = user.to_string() not in users filtered = await filter_events_for_client( - self.storage, user.to_string(), [event], is_peeking=is_peeking + self._storage_controllers, user.to_string(), [event], is_peeking=is_peeking ) if not filtered: diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 0386d0a07b..80ee7e7b4e 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -125,8 +125,8 @@ class FederationHandler: self.hs = hs self.store = hs.get_datastores().main - self.storage = hs.get_storage() - self.state_store = self.storage.state + self._storage_controllers = hs.get_storage_controllers() + self._state_storage_controller = self._storage_controllers.state self.federation_client = hs.get_federation_client() self.state_handler = hs.get_state_handler() self.server_name = hs.hostname @@ -324,7 +324,7 @@ class FederationHandler: # We set `check_history_visibility_only` as we might otherwise get false # positives from users having been erased. filtered_extremities = await filter_events_for_server( - self.storage, + self._storage_controllers, self.server_name, events_to_check, redact=False, @@ -660,7 +660,7 @@ class FederationHandler: # in the invitee's sync stream. It is stripped out for all other local users. event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"] - context = EventContext.for_outlier(self.storage) + context = EventContext.for_outlier(self._storage_controllers) stream_id = await self._federation_event_handler.persist_events_and_notify( event.room_id, [(event, context)] ) @@ -849,7 +849,7 @@ class FederationHandler: ) ) - context = EventContext.for_outlier(self.storage) + context = EventContext.for_outlier(self._storage_controllers) await self._federation_event_handler.persist_events_and_notify( event.room_id, [(event, context)] ) @@ -878,7 +878,7 @@ class FederationHandler: await self.federation_client.send_leave(host_list, event) - context = EventContext.for_outlier(self.storage) + context = EventContext.for_outlier(self._storage_controllers) stream_id = await self._federation_event_handler.persist_events_and_notify( event.room_id, [(event, context)] ) @@ -1027,7 +1027,9 @@ class FederationHandler: if event.internal_metadata.outlier: raise NotFoundError("State not known at event %s" % (event_id,)) - state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id]) + state_groups = await self._state_storage_controller.get_state_groups_ids( + room_id, [event_id] + ) # get_state_groups_ids should return exactly one result assert len(state_groups) == 1 @@ -1076,7 +1078,9 @@ class FederationHandler: ], ) - events = await filter_events_for_server(self.storage, origin, events) + events = await filter_events_for_server( + self._storage_controllers, origin, events + ) return events @@ -1107,7 +1111,9 @@ class FederationHandler: if not in_room: raise AuthError(403, "Host not in room.") - events = await filter_events_for_server(self.storage, origin, [event]) + events = await filter_events_for_server( + self._storage_controllers, origin, [event] + ) event = events[0] return event else: @@ -1136,7 +1142,7 @@ class FederationHandler: ) missing_events = await filter_events_for_server( - self.storage, origin, missing_events + self._storage_controllers, origin, missing_events ) return missing_events @@ -1478,9 +1484,11 @@ class FederationHandler: # clear the lazy-loading flag. logger.info("Updating current state for %s", room_id) assert ( - self.storage.persistence is not None + self._storage_controllers.persistence is not None ), "TODO(faster_joins): support for workers" - await self.storage.persistence.update_current_state(room_id) + await self._storage_controllers.persistence.update_current_state( + room_id + ) logger.info("Clearing partial-state flag for %s", room_id) success = await self.store.clear_partial_state_room(room_id) diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index ca82df8a6d..b908674529 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -98,8 +98,8 @@ class FederationEventHandler: def __init__(self, hs: "HomeServer"): self._store = hs.get_datastores().main - self._storage = hs.get_storage() - self._state_store = self._storage.state + self._storage_controllers = hs.get_storage_controllers() + self._state_storage_controller = self._storage_controllers.state self._state_handler = hs.get_state_handler() self._event_creation_handler = hs.get_event_creation_handler() @@ -274,7 +274,7 @@ class FederationEventHandler: affected=pdu.event_id, ) - await self._process_received_pdu(origin, pdu, state=None) + await self._process_received_pdu(origin, pdu, state_ids=None) async def on_send_membership_event( self, origin: str, event: EventBase @@ -463,7 +463,9 @@ class FederationEventHandler: with nested_logging_context(suffix=event.event_id): context = await self._state_handler.compute_event_context( event, - old_state=state, + state_ids_before_event={ + (e.type, e.state_key): e.event_id for e in state + }, partial_state=partial_state, ) @@ -512,12 +514,12 @@ class FederationEventHandler: # # This is the same operation as we do when we receive a regular event # over federation. - state = await self._resolve_state_at_missing_prevs(destination, event) + state_ids = await self._resolve_state_at_missing_prevs(destination, event) # build a new state group for it if need be context = await self._state_handler.compute_event_context( event, - old_state=state, + state_ids_before_event=state_ids, ) if context.partial_state: # this can happen if some or all of the event's prev_events still have @@ -533,7 +535,9 @@ class FederationEventHandler: ) return await self._store.update_state_for_partial_state_event(event, context) - self._state_store.notify_event_un_partial_stated(event.event_id) + self._state_storage_controller.notify_event_un_partial_stated( + event.event_id + ) async def backfill( self, dest: str, room_id: str, limit: int, extremities: Collection[str] @@ -767,11 +771,12 @@ class FederationEventHandler: return try: - state = await self._resolve_state_at_missing_prevs(origin, event) + state_ids = await self._resolve_state_at_missing_prevs(origin, event) # TODO(faster_joins): make sure that _resolve_state_at_missing_prevs does # not return partial state + await self._process_received_pdu( - origin, event, state=state, backfilled=backfilled + origin, event, state_ids=state_ids, backfilled=backfilled ) except FederationError as e: if e.code == 403: @@ -781,7 +786,7 @@ class FederationEventHandler: async def _resolve_state_at_missing_prevs( self, dest: str, event: EventBase - ) -> Optional[Iterable[EventBase]]: + ) -> Optional[StateMap[str]]: """Calculate the state at an event with missing prev_events. This is used when we have pulled a batch of events from a remote server, and @@ -808,8 +813,8 @@ class FederationEventHandler: event: an event to check for missing prevs. Returns: - if we already had all the prev events, `None`. Otherwise, returns a list of - the events in the state at `event`. + if we already had all the prev events, `None`. Otherwise, returns + the event ids of the state at `event`. """ room_id = event.room_id event_id = event.event_id @@ -829,10 +834,12 @@ class FederationEventHandler: ) # Calculate the state after each of the previous events, and # resolve them to find the correct state at the current event. - event_map = {event_id: event} + try: # Get the state of the events we know about - ours = await self._state_store.get_state_groups_ids(room_id, seen) + ours = await self._state_storage_controller.get_state_groups_ids( + room_id, seen + ) # state_maps is a list of mappings from (type, state_key) to event_id state_maps: List[StateMap[str]] = list(ours.values()) @@ -849,40 +856,23 @@ class FederationEventHandler: # note that if any of the missing prevs share missing state or # auth events, the requests to fetch those events are deduped # by the get_pdu_cache in federation_client. - remote_state = await self._get_state_after_missing_prev_event( - dest, room_id, p + remote_state_map = ( + await self._get_state_ids_after_missing_prev_event( + dest, room_id, p + ) ) - remote_state_map = { - (x.type, x.state_key): x.event_id for x in remote_state - } state_maps.append(remote_state_map) - for x in remote_state: - event_map[x.event_id] = x - room_version = await self._store.get_room_version_id(room_id) state_map = await self._state_resolution_handler.resolve_events_with_store( room_id, room_version, state_maps, - event_map, + event_map={event_id: event}, state_res_store=StateResolutionStore(self._store), ) - # We need to give _process_received_pdu the actual state events - # rather than event ids, so generate that now. - - # First though we need to fetch all the events that are in - # state_map, so we can build up the state below. - evs = await self._store.get_events( - list(state_map.values()), - get_prev_content=False, - redact_behaviour=EventRedactBehaviour.as_is, - ) - event_map.update(evs) - - state = [event_map[e] for e in state_map.values()] except Exception: logger.warning( "Error attempting to resolve state at missing prev_events", @@ -894,14 +884,14 @@ class FederationEventHandler: "We can't get valid state history.", affected=event_id, ) - return state + return state_map - async def _get_state_after_missing_prev_event( + async def _get_state_ids_after_missing_prev_event( self, destination: str, room_id: str, event_id: str, - ) -> List[EventBase]: + ) -> StateMap[str]: """Requests all of the room state at a given event from a remote homeserver. Args: @@ -910,7 +900,7 @@ class FederationEventHandler: event_id: The id of the event we want the state at. Returns: - A list of events in the state, including the event itself + The event ids of the state *after* the given event. """ ( state_event_ids, @@ -925,19 +915,17 @@ class FederationEventHandler: len(auth_event_ids), ) - # start by just trying to fetch the events from the store + # Start by checking events we already have in the DB desired_events = set(state_event_ids) desired_events.add(event_id) logger.debug("Fetching %i events from cache/store", len(desired_events)) - fetched_events = await self._store.get_events( - desired_events, allow_rejected=True - ) + have_events = await self._store.have_seen_events(room_id, desired_events) - missing_desired_events = desired_events - fetched_events.keys() + missing_desired_events = desired_events - have_events logger.debug( "We are missing %i events (got %i)", len(missing_desired_events), - len(fetched_events), + len(have_events), ) # We probably won't need most of the auth events, so let's just check which @@ -948,7 +936,7 @@ class FederationEventHandler: # already have a bunch of the state events. It would be nice if the # federation api gave us a way of finding out which we actually need. - missing_auth_events = set(auth_event_ids) - fetched_events.keys() + missing_auth_events = set(auth_event_ids) - have_events missing_auth_events.difference_update( await self._store.have_seen_events(room_id, missing_auth_events) ) @@ -974,47 +962,51 @@ class FederationEventHandler: destination=destination, room_id=room_id, event_ids=missing_events ) - # we need to make sure we re-load from the database to get the rejected - # state correct. - fetched_events.update( - await self._store.get_events(missing_desired_events, allow_rejected=True) - ) + # We now need to fill out the state map, which involves fetching the + # type and state key for each event ID in the state. + state_map = {} - # check for events which were in the wrong room. - # - # this can happen if a remote server claims that the state or - # auth_events at an event in room A are actually events in room B - - bad_events = [ - (event_id, event.room_id) - for event_id, event in fetched_events.items() - if event.room_id != room_id - ] + event_metadata = await self._store.get_metadata_for_events(state_event_ids) + for state_event_id, metadata in event_metadata.items(): + if metadata.room_id != room_id: + # This is a bogus situation, but since we may only discover it a long time + # after it happened, we try our best to carry on, by just omitting the + # bad events from the returned state set. + # + # This can happen if a remote server claims that the state or + # auth_events at an event in room A are actually events in room B + logger.warning( + "Remote server %s claims event %s in room %s is an auth/state " + "event in room %s", + destination, + state_event_id, + metadata.room_id, + room_id, + ) + continue - for bad_event_id, bad_room_id in bad_events: - # This is a bogus situation, but since we may only discover it a long time - # after it happened, we try our best to carry on, by just omitting the - # bad events from the returned state set. - logger.warning( - "Remote server %s claims event %s in room %s is an auth/state " - "event in room %s", - destination, - bad_event_id, - bad_room_id, - room_id, - ) + if metadata.state_key is None: + logger.warning( + "Remote server gave us non-state event in state: %s", state_event_id + ) + continue - del fetched_events[bad_event_id] + state_map[(metadata.event_type, metadata.state_key)] = state_event_id # if we couldn't get the prev event in question, that's a problem. - remote_event = fetched_events.get(event_id) + remote_event = await self._store.get_event( + event_id, + allow_none=True, + allow_rejected=True, + redact_behaviour=EventRedactBehaviour.as_is, + ) if not remote_event: raise Exception("Unable to get missing prev_event %s" % (event_id,)) # missing state at that event is a warning, not a blocker # XXX: this doesn't sound right? it means that we'll end up with incomplete # state. - failed_to_fetch = desired_events - fetched_events.keys() + failed_to_fetch = desired_events - event_metadata.keys() if failed_to_fetch: logger.warning( "Failed to fetch missing state events for %s %s", @@ -1022,14 +1014,12 @@ class FederationEventHandler: failed_to_fetch, ) - remote_state = [ - fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events - ] - if remote_event.is_state() and remote_event.rejected_reason is None: - remote_state.append(remote_event) + state_map[ + (remote_event.type, remote_event.state_key) + ] = remote_event.event_id - return remote_state + return state_map async def _get_state_and_persist( self, destination: str, room_id: str, event_id: str @@ -1056,7 +1046,7 @@ class FederationEventHandler: self, origin: str, event: EventBase, - state: Optional[Iterable[EventBase]], + state_ids: Optional[StateMap[str]], backfilled: bool = False, ) -> None: """Called when we have a new non-outlier event. @@ -1078,7 +1068,7 @@ class FederationEventHandler: event: event to be persisted - state: Normally None, but if we are handling a gap in the graph + state_ids: Normally None, but if we are handling a gap in the graph (ie, we are missing one or more prev_events), the resolved state at the event @@ -1090,7 +1080,8 @@ class FederationEventHandler: try: context = await self._state_handler.compute_event_context( - event, old_state=state + event, + state_ids_before_event=state_ids, ) context = await self._check_event_auth( origin, @@ -1107,7 +1098,7 @@ class FederationEventHandler: # For new (non-backfilled and non-outlier) events we check if the event # passes auth based on the current state. If it doesn't then we # "soft-fail" the event. - await self._check_for_soft_fail(event, state, origin=origin) + await self._check_for_soft_fail(event, state_ids, origin=origin) await self._run_push_actions_and_persist_event(event, context, backfilled) @@ -1449,7 +1440,7 @@ class FederationEventHandler: # we're not bothering about room state, so flag the event as an outlier. event.internal_metadata.outlier = True - context = EventContext.for_outlier(self._storage) + context = EventContext.for_outlier(self._storage_controllers) try: validate_event_for_room_version(room_version_obj, event) check_auth_rules_for_event(room_version_obj, event, auth) @@ -1589,7 +1580,7 @@ class FederationEventHandler: async def _check_for_soft_fail( self, event: EventBase, - state: Optional[Iterable[EventBase]], + state_ids: Optional[StateMap[str]], origin: str, ) -> None: """Checks if we should soft fail the event; if so, marks the event as @@ -1597,7 +1588,7 @@ class FederationEventHandler: Args: event - state: The state at the event if we don't have all the event's prev events + state_ids: The state at the event if we don't have all the event's prev events origin: The host the event originates from. """ extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id) @@ -1613,7 +1604,7 @@ class FederationEventHandler: room_version_obj = KNOWN_ROOM_VERSIONS[room_version] # Calculate the "current state". - if state is not None: + if state_ids is not None: # If we're explicitly given the state then we won't have all the # prev events, and so we have a gap in the graph. In this case # we want to be a little careful as we might have been down for @@ -1626,17 +1617,20 @@ class FederationEventHandler: # given state at the event. This should correctly handle cases # like bans, especially with state res v2. - state_sets_d = await self._state_store.get_state_groups( + state_sets_d = await self._state_storage_controller.get_state_groups_ids( event.room_id, extrem_ids ) - state_sets: List[Iterable[EventBase]] = list(state_sets_d.values()) - state_sets.append(state) - current_states = await self._state_handler.resolve_events( - room_version, state_sets, event + state_sets: List[StateMap[str]] = list(state_sets_d.values()) + state_sets.append(state_ids) + current_state_ids = ( + await self._state_resolution_handler.resolve_events_with_store( + event.room_id, + room_version, + state_sets, + event_map=None, + state_res_store=StateResolutionStore(self._store), + ) ) - current_state_ids: StateMap[str] = { - k: e.event_id for k, e in current_states.items() - } else: current_state_ids = await self._state_handler.get_current_state_ids( event.room_id, latest_event_ids=extrem_ids @@ -1895,7 +1889,7 @@ class FederationEventHandler: # create a new state group as a delta from the existing one. prev_group = context.state_group - state_group = await self._state_store.store_state_group( + state_group = await self._state_storage_controller.store_state_group( event.event_id, event.room_id, prev_group=prev_group, @@ -1904,7 +1898,7 @@ class FederationEventHandler: ) return EventContext.with_state( - storage=self._storage, + storage=self._storage_controllers, state_group=state_group, state_group_before_event=context.state_group_before_event, state_delta_due_to_event=state_updates, @@ -1994,11 +1988,14 @@ class FederationEventHandler: ) return result["max_stream_id"] else: - assert self._storage.persistence + assert self._storage_controllers.persistence # Note that this returns the events that were persisted, which may not be # the same as were passed in if some were deduplicated due to transaction IDs. - events, max_stream_token = await self._storage.persistence.persist_events( + ( + events, + max_stream_token, + ) = await self._storage_controllers.persistence.persist_events( event_and_contexts, backfilled=backfilled ) diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py deleted file mode 100644 index e7a399787b..0000000000 --- a/synapse/handlers/groups_local.py +++ /dev/null @@ -1,503 +0,0 @@ -# Copyright 2017 Vector Creations Ltd -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Iterable, List, Set - -from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError -from synapse.types import GroupID, JsonDict, get_domain_from_id - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -def _create_rerouter(func_name: str) -> Callable[..., Awaitable[JsonDict]]: - """Returns an async function that looks at the group id and calls the function - on federation or the local group server if the group is local - """ - - async def f( - self: "GroupsLocalWorkerHandler", group_id: str, *args: Any, **kwargs: Any - ) -> JsonDict: - if not GroupID.is_valid(group_id): - raise SynapseError(400, "%s is not a legal group ID" % (group_id,)) - - if self.is_mine_id(group_id): - return await getattr(self.groups_server_handler, func_name)( - group_id, *args, **kwargs - ) - else: - destination = get_domain_from_id(group_id) - - try: - return await getattr(self.transport_client, func_name)( - destination, group_id, *args, **kwargs - ) - except HttpResponseException as e: - # Capture errors returned by the remote homeserver and - # re-throw specific errors as SynapseErrors. This is so - # when the remote end responds with things like 403 Not - # In Group, we can communicate that to the client instead - # of a 500. - raise e.to_synapse_error() - except RequestSendFailed: - raise SynapseError(502, "Failed to contact group server") - - return f - - -class GroupsLocalWorkerHandler: - def __init__(self, hs: "HomeServer"): - self.hs = hs - self.store = hs.get_datastores().main - self.room_list_handler = hs.get_room_list_handler() - self.groups_server_handler = hs.get_groups_server_handler() - self.transport_client = hs.get_federation_transport_client() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.keyring = hs.get_keyring() - self.is_mine_id = hs.is_mine_id - self.signing_key = hs.signing_key - self.server_name = hs.hostname - self.notifier = hs.get_notifier() - self.attestations = hs.get_groups_attestation_signing() - - self.profile_handler = hs.get_profile_handler() - - # The following functions merely route the query to the local groups server - # or federation depending on if the group is local or remote - - get_group_profile = _create_rerouter("get_group_profile") - get_rooms_in_group = _create_rerouter("get_rooms_in_group") - get_invited_users_in_group = _create_rerouter("get_invited_users_in_group") - get_group_category = _create_rerouter("get_group_category") - get_group_categories = _create_rerouter("get_group_categories") - get_group_role = _create_rerouter("get_group_role") - get_group_roles = _create_rerouter("get_group_roles") - - async def get_group_summary( - self, group_id: str, requester_user_id: str - ) -> JsonDict: - """Get the group summary for a group. - - If the group is remote we check that the users have valid attestations. - """ - if self.is_mine_id(group_id): - res = await self.groups_server_handler.get_group_summary( - group_id, requester_user_id - ) - else: - try: - res = await self.transport_client.get_group_summary( - get_domain_from_id(group_id), group_id, requester_user_id - ) - except HttpResponseException as e: - raise e.to_synapse_error() - except RequestSendFailed: - raise SynapseError(502, "Failed to contact group server") - - group_server_name = get_domain_from_id(group_id) - - # Loop through the users and validate the attestations. - chunk = res["users_section"]["users"] - valid_users = [] - for entry in chunk: - g_user_id = entry["user_id"] - attestation = entry.pop("attestation", {}) - try: - if get_domain_from_id(g_user_id) != group_server_name: - await self.attestations.verify_attestation( - attestation, - group_id=group_id, - user_id=g_user_id, - server_name=get_domain_from_id(g_user_id), - ) - valid_users.append(entry) - except Exception as e: - logger.info("Failed to verify user is in group: %s", e) - - res["users_section"]["users"] = valid_users - - res["users_section"]["users"].sort(key=lambda e: e.get("order", 0)) - res["rooms_section"]["rooms"].sort(key=lambda e: e.get("order", 0)) - - # Add `is_publicised` flag to indicate whether the user has publicised their - # membership of the group on their profile - result = await self.store.get_publicised_groups_for_user(requester_user_id) - is_publicised = group_id in result - - res.setdefault("user", {})["is_publicised"] = is_publicised - - return res - - async def get_users_in_group( - self, group_id: str, requester_user_id: str - ) -> JsonDict: - """Get users in a group""" - if self.is_mine_id(group_id): - return await self.groups_server_handler.get_users_in_group( - group_id, requester_user_id - ) - - group_server_name = get_domain_from_id(group_id) - - try: - res = await self.transport_client.get_users_in_group( - get_domain_from_id(group_id), group_id, requester_user_id - ) - except HttpResponseException as e: - raise e.to_synapse_error() - except RequestSendFailed: - raise SynapseError(502, "Failed to contact group server") - - chunk = res["chunk"] - valid_entries = [] - for entry in chunk: - g_user_id = entry["user_id"] - attestation = entry.pop("attestation", {}) - try: - if get_domain_from_id(g_user_id) != group_server_name: - await self.attestations.verify_attestation( - attestation, - group_id=group_id, - user_id=g_user_id, - server_name=get_domain_from_id(g_user_id), - ) - valid_entries.append(entry) - except Exception as e: - logger.info("Failed to verify user is in group: %s", e) - - res["chunk"] = valid_entries - - return res - - async def get_joined_groups(self, user_id: str) -> JsonDict: - group_ids = await self.store.get_joined_groups(user_id) - return {"groups": group_ids} - - async def get_publicised_groups_for_user(self, user_id: str) -> JsonDict: - if self.hs.is_mine_id(user_id): - result = await self.store.get_publicised_groups_for_user(user_id) - - # Check AS associated groups for this user - this depends on the - # RegExps in the AS registration file (under `users`) - for app_service in self.store.get_app_services(): - result.extend(app_service.get_groups_for_user(user_id)) - - return {"groups": result} - else: - try: - bulk_result = await self.transport_client.bulk_get_publicised_groups( - get_domain_from_id(user_id), [user_id] - ) - except HttpResponseException as e: - raise e.to_synapse_error() - except RequestSendFailed: - raise SynapseError(502, "Failed to contact group server") - - result = bulk_result.get("users", {}).get(user_id) - # TODO: Verify attestations - return {"groups": result} - - async def bulk_get_publicised_groups( - self, user_ids: Iterable[str], proxy: bool = True - ) -> JsonDict: - destinations: Dict[str, Set[str]] = {} - local_users = set() - - for user_id in user_ids: - if self.hs.is_mine_id(user_id): - local_users.add(user_id) - else: - destinations.setdefault(get_domain_from_id(user_id), set()).add(user_id) - - if not proxy and destinations: - raise SynapseError(400, "Some user_ids are not local") - - results = {} - failed_results: List[str] = [] - for destination, dest_user_ids in destinations.items(): - try: - r = await self.transport_client.bulk_get_publicised_groups( - destination, list(dest_user_ids) - ) - results.update(r["users"]) - except Exception: - failed_results.extend(dest_user_ids) - - for uid in local_users: - results[uid] = await self.store.get_publicised_groups_for_user(uid) - - # Check AS associated groups for this user - this depends on the - # RegExps in the AS registration file (under `users`) - for app_service in self.store.get_app_services(): - results[uid].extend(app_service.get_groups_for_user(uid)) - - return {"users": results} - - -class GroupsLocalHandler(GroupsLocalWorkerHandler): - def __init__(self, hs: "HomeServer"): - super().__init__(hs) - - # Ensure attestations get renewed - hs.get_groups_attestation_renewer() - - # The following functions merely route the query to the local groups server - # or federation depending on if the group is local or remote - - update_group_profile = _create_rerouter("update_group_profile") - - add_room_to_group = _create_rerouter("add_room_to_group") - update_room_in_group = _create_rerouter("update_room_in_group") - remove_room_from_group = _create_rerouter("remove_room_from_group") - - update_group_summary_room = _create_rerouter("update_group_summary_room") - delete_group_summary_room = _create_rerouter("delete_group_summary_room") - - update_group_category = _create_rerouter("update_group_category") - delete_group_category = _create_rerouter("delete_group_category") - - update_group_summary_user = _create_rerouter("update_group_summary_user") - delete_group_summary_user = _create_rerouter("delete_group_summary_user") - - update_group_role = _create_rerouter("update_group_role") - delete_group_role = _create_rerouter("delete_group_role") - - set_group_join_policy = _create_rerouter("set_group_join_policy") - - async def create_group( - self, group_id: str, user_id: str, content: JsonDict - ) -> JsonDict: - """Create a group""" - - logger.info("Asking to create group with ID: %r", group_id) - - if self.is_mine_id(group_id): - res = await self.groups_server_handler.create_group( - group_id, user_id, content - ) - local_attestation = None - remote_attestation = None - else: - raise SynapseError(400, "Unable to create remote groups") - - is_publicised = content.get("publicise", False) - token = await self.store.register_user_group_membership( - group_id, - user_id, - membership="join", - is_admin=True, - local_attestation=local_attestation, - remote_attestation=remote_attestation, - is_publicised=is_publicised, - ) - self.notifier.on_new_event("groups_key", token, users=[user_id]) - - return res - - async def join_group( - self, group_id: str, user_id: str, content: JsonDict - ) -> JsonDict: - """Request to join a group""" - if self.is_mine_id(group_id): - await self.groups_server_handler.join_group(group_id, user_id, content) - local_attestation = None - remote_attestation = None - else: - local_attestation = self.attestations.create_attestation(group_id, user_id) - content["attestation"] = local_attestation - - try: - res = await self.transport_client.join_group( - get_domain_from_id(group_id), group_id, user_id, content - ) - except HttpResponseException as e: - raise e.to_synapse_error() - except RequestSendFailed: - raise SynapseError(502, "Failed to contact group server") - - remote_attestation = res["attestation"] - - await self.attestations.verify_attestation( - remote_attestation, - group_id=group_id, - user_id=user_id, - server_name=get_domain_from_id(group_id), - ) - - # TODO: Check that the group is public and we're being added publicly - is_publicised = content.get("publicise", False) - - token = await self.store.register_user_group_membership( - group_id, - user_id, - membership="join", - is_admin=False, - local_attestation=local_attestation, - remote_attestation=remote_attestation, - is_publicised=is_publicised, - ) - self.notifier.on_new_event("groups_key", token, users=[user_id]) - - return {} - - async def accept_invite( - self, group_id: str, user_id: str, content: JsonDict - ) -> JsonDict: - """Accept an invite to a group""" - if self.is_mine_id(group_id): - await self.groups_server_handler.accept_invite(group_id, user_id, content) - local_attestation = None - remote_attestation = None - else: - local_attestation = self.attestations.create_attestation(group_id, user_id) - content["attestation"] = local_attestation - - try: - res = await self.transport_client.accept_group_invite( - get_domain_from_id(group_id), group_id, user_id, content - ) - except HttpResponseException as e: - raise e.to_synapse_error() - except RequestSendFailed: - raise SynapseError(502, "Failed to contact group server") - - remote_attestation = res["attestation"] - - await self.attestations.verify_attestation( - remote_attestation, - group_id=group_id, - user_id=user_id, - server_name=get_domain_from_id(group_id), - ) - - # TODO: Check that the group is public and we're being added publicly - is_publicised = content.get("publicise", False) - - token = await self.store.register_user_group_membership( - group_id, - user_id, - membership="join", - is_admin=False, - local_attestation=local_attestation, - remote_attestation=remote_attestation, - is_publicised=is_publicised, - ) - self.notifier.on_new_event("groups_key", token, users=[user_id]) - - return {} - - async def invite( - self, group_id: str, user_id: str, requester_user_id: str, config: JsonDict - ) -> JsonDict: - """Invite a user to a group""" - content = {"requester_user_id": requester_user_id, "config": config} - if self.is_mine_id(group_id): - res = await self.groups_server_handler.invite_to_group( - group_id, user_id, requester_user_id, content - ) - else: - try: - res = await self.transport_client.invite_to_group( - get_domain_from_id(group_id), - group_id, - user_id, - requester_user_id, - content, - ) - except HttpResponseException as e: - raise e.to_synapse_error() - except RequestSendFailed: - raise SynapseError(502, "Failed to contact group server") - - return res - - async def on_invite( - self, group_id: str, user_id: str, content: JsonDict - ) -> JsonDict: - """One of our users were invited to a group""" - # TODO: Support auto join and rejection - - if not self.is_mine_id(user_id): - raise SynapseError(400, "User not on this server") - - local_profile = {} - if "profile" in content: - if "name" in content["profile"]: - local_profile["name"] = content["profile"]["name"] - if "avatar_url" in content["profile"]: - local_profile["avatar_url"] = content["profile"]["avatar_url"] - - token = await self.store.register_user_group_membership( - group_id, - user_id, - membership="invite", - content={"profile": local_profile, "inviter": content["inviter"]}, - ) - self.notifier.on_new_event("groups_key", token, users=[user_id]) - try: - user_profile = await self.profile_handler.get_profile(user_id) - except Exception as e: - logger.warning("No profile for user %s: %s", user_id, e) - user_profile = {} - - return {"state": "invite", "user_profile": user_profile} - - async def remove_user_from_group( - self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict - ) -> JsonDict: - """Remove a user from a group""" - if user_id == requester_user_id: - token = await self.store.register_user_group_membership( - group_id, user_id, membership="leave" - ) - self.notifier.on_new_event("groups_key", token, users=[user_id]) - - # TODO: Should probably remember that we tried to leave so that we can - # retry if the group server is currently down. - - if self.is_mine_id(group_id): - res = await self.groups_server_handler.remove_user_from_group( - group_id, user_id, requester_user_id, content - ) - else: - content["requester_user_id"] = requester_user_id - try: - res = await self.transport_client.remove_user_from_group( - get_domain_from_id(group_id), - group_id, - requester_user_id, - user_id, - content, - ) - except HttpResponseException as e: - raise e.to_synapse_error() - except RequestSendFailed: - raise SynapseError(502, "Failed to contact group server") - - return res - - async def user_removed_from_group( - self, group_id: str, user_id: str, content: JsonDict - ) -> None: - """One of our users was removed/kicked from a group""" - # TODO: Check if user in group - token = await self.store.register_user_group_membership( - group_id, user_id, membership="leave" - ) - self.notifier.on_new_event("groups_key", token, users=[user_id]) diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index d79248ad90..d2b489e816 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -67,8 +67,8 @@ class InitialSyncHandler: ] ] = ResponseCache(hs.get_clock(), "initial_sync_cache") self._event_serializer = hs.get_event_client_serializer() - self.storage = hs.get_storage() - self.state_store = self.storage.state + self._storage_controllers = hs.get_storage_controllers() + self._state_storage_controller = self._storage_controllers.state async def snapshot_all_rooms( self, @@ -198,7 +198,8 @@ class InitialSyncHandler: event.stream_ordering, ) deferred_room_state = run_in_background( - self.state_store.get_state_for_events, [event.event_id] + self._state_storage_controller.get_state_for_events, + [event.event_id], ).addCallback( lambda states: cast(StateMap[EventBase], states[event.event_id]) ) @@ -218,7 +219,7 @@ class InitialSyncHandler: ).addErrback(unwrapFirstError) messages = await filter_events_for_client( - self.storage, user_id, messages + self._storage_controllers, user_id, messages ) start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token) @@ -274,7 +275,7 @@ class InitialSyncHandler: "rooms": rooms_ret, "presence": [ { - "type": "m.presence", + "type": EduTypes.PRESENCE, "content": format_user_presence_state(event, now), } for event in presence @@ -355,7 +356,9 @@ class InitialSyncHandler: member_event_id: str, is_peeking: bool, ) -> JsonDict: - room_state = await self.state_store.get_state_for_event(member_event_id) + room_state = await self._state_storage_controller.get_state_for_event( + member_event_id + ) limit = pagin_config.limit if pagin_config else None if limit is None: @@ -369,7 +372,7 @@ class InitialSyncHandler: ) messages = await filter_events_for_client( - self.storage, user_id, messages, is_peeking=is_peeking + self._storage_controllers, user_id, messages, is_peeking=is_peeking ) start_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, token) @@ -439,7 +442,7 @@ class InitialSyncHandler: return [ { - "type": EduTypes.Presence, + "type": EduTypes.PRESENCE, "content": format_user_presence_state(s, time_now), } for s in states @@ -474,7 +477,7 @@ class InitialSyncHandler: ) messages = await filter_events_for_client( - self.storage, user_id, messages, is_peeking=is_peeking + self._storage_controllers, user_id, messages, is_peeking=is_peeking ) start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 22cdad3f33..cf7c2d1979 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -54,7 +54,14 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.state import StateFilter -from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester +from synapse.types import ( + MutableStateMap, + Requester, + RoomAlias, + StreamToken, + UserID, + create_requester, +) from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError from synapse.util.async_helpers import Linearizer, gather_results from synapse.util.caches.expiringcache import ExpiringCache @@ -76,8 +83,8 @@ class MessageHandler: self.clock = hs.get_clock() self.state = hs.get_state_handler() self.store = hs.get_datastores().main - self.storage = hs.get_storage() - self.state_store = self.storage.state + self._storage_controllers = hs.get_storage_controllers() + self._state_storage_controller = self._storage_controllers.state self._event_serializer = hs.get_event_client_serializer() self._ephemeral_events_enabled = hs.config.server.enable_ephemeral_messages @@ -124,7 +131,7 @@ class MessageHandler: assert ( membership_event_id is not None ), "check_user_in_room_or_world_readable returned invalid data" - room_state = await self.state_store.get_state_for_events( + room_state = await self._state_storage_controller.get_state_for_events( [membership_event_id], StateFilter.from_types([key]) ) data = room_state[membership_event_id].get(key) @@ -185,7 +192,7 @@ class MessageHandler: # check whether the user is in the room at that time to determine # whether they should be treated as peeking. - state_map = await self.state_store.get_state_for_event( + state_map = await self._state_storage_controller.get_state_for_event( last_event.event_id, StateFilter.from_types([(EventTypes.Member, user_id)]), ) @@ -198,7 +205,7 @@ class MessageHandler: is_peeking = not joined visible_events = await filter_events_for_client( - self.storage, + self._storage_controllers, user_id, [last_event], filter_send_to_client=False, @@ -206,8 +213,10 @@ class MessageHandler: ) if visible_events: - room_state_events = await self.state_store.get_state_for_events( - [last_event.event_id], state_filter=state_filter + room_state_events = ( + await self._state_storage_controller.get_state_for_events( + [last_event.event_id], state_filter=state_filter + ) ) room_state: Mapping[Any, EventBase] = room_state_events[ last_event.event_id @@ -236,8 +245,10 @@ class MessageHandler: assert ( membership_event_id is not None ), "check_user_in_room_or_world_readable returned invalid data" - room_state_events = await self.state_store.get_state_for_events( - [membership_event_id], state_filter=state_filter + room_state_events = ( + await self._state_storage_controller.get_state_for_events( + [membership_event_id], state_filter=state_filter + ) ) room_state = room_state_events[membership_event_id] @@ -394,7 +405,7 @@ class EventCreationHandler: self.auth = hs.get_auth() self._event_auth_handler = hs.get_event_auth_handler() self.store = hs.get_datastores().main - self.storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() self.state = hs.get_state_handler() self.clock = hs.get_clock() self.validator = EventValidator() @@ -887,6 +898,22 @@ class EventCreationHandler: spam_check_result = await self.spam_checker.check_event_for_spam(event) if spam_check_result != self.spam_checker.NOT_SPAM: + if isinstance(spam_check_result, tuple): + try: + [code, dict] = spam_check_result + raise SynapseError( + 403, + "This message had been rejected as probable spam", + code, + dict, + ) + except ValueError: + logger.error( + "Spam-check module returned invalid error value. Expecting [code, dict], got %s", + spam_check_result, + ) + spam_check_result = Codes.FORBIDDEN + if isinstance(spam_check_result, Codes): raise SynapseError( 403, @@ -1021,7 +1048,7 @@ class EventCreationHandler: # after it is created if builder.internal_metadata.outlier: event.internal_metadata.outlier = True - context = EventContext.for_outlier(self.storage) + context = EventContext.for_outlier(self._storage_controllers) elif ( event.type == EventTypes.MSC2716_INSERTION and state_event_ids @@ -1033,8 +1060,35 @@ class EventCreationHandler: # # TODO(faster_joins): figure out how this works, and make sure that the # old state is complete. - old_state = await self.store.get_events_as_list(state_event_ids) - context = await self.state.compute_event_context(event, old_state=old_state) + metadata = await self.store.get_metadata_for_events(state_event_ids) + + state_map_for_event: MutableStateMap[str] = {} + for state_id in state_event_ids: + data = metadata.get(state_id) + if data is None: + # We're trying to persist a new historical batch of events + # with the given state, e.g. via + # `RoomBatchSendEventRestServlet`. The state can be inferred + # by Synapse or set directly by the client. + # + # Either way, we should have persisted all the state before + # getting here. + raise Exception( + f"State event {state_id} not found in DB," + " Synapse should have persisted it before using it." + ) + + if data.state_key is None: + raise Exception( + f"Trying to set non-state event {state_id} as state" + ) + + state_map_for_event[(data.event_type, data.state_key)] = state_id + + context = await self.state.compute_event_context( + event, + state_ids_before_event=state_map_for_event, + ) else: context = await self.state.compute_event_context(event) @@ -1407,7 +1461,7 @@ class EventCreationHandler: """ extra_users = extra_users or [] - assert self.storage.persistence is not None + assert self._storage_controllers.persistence is not None assert self._events_shard_config.should_handle( self._instance_name, event.room_id ) @@ -1641,7 +1695,7 @@ class EventCreationHandler: event, event_pos, max_stream_token, - ) = await self.storage.persistence.persist_event( + ) = await self._storage_controllers.persistence.persist_event( event, context=context, backfilled=backfilled ) diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 19a4407050..6262a35822 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -129,8 +129,8 @@ class PaginationHandler: self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastores().main - self.storage = hs.get_storage() - self.state_store = self.storage.state + self._storage_controllers = hs.get_storage_controllers() + self._state_storage_controller = self._storage_controllers.state self.clock = hs.get_clock() self._server_name = hs.hostname self._room_shutdown_handler = hs.get_room_shutdown_handler() @@ -352,7 +352,7 @@ class PaginationHandler: self._purges_in_progress_by_room.add(room_id) try: async with self.pagination_lock.write(room_id): - await self.storage.purge_events.purge_history( + await self._storage_controllers.purge_events.purge_history( room_id, token, delete_local_events ) logger.info("[purge] complete") @@ -414,7 +414,7 @@ class PaginationHandler: if joined: raise SynapseError(400, "Users are still joined to this room") - await self.storage.purge_events.purge_room(room_id) + await self._storage_controllers.purge_events.purge_room(room_id) async def get_messages( self, @@ -515,14 +515,28 @@ class PaginationHandler: next_token = from_token.copy_and_replace(StreamKeyType.ROOM, next_key) - if events: - if event_filter: - events = await event_filter.filter(events) + # if no events are returned from pagination, that implies + # we have reached the end of the available events. + # In that case we do not return end, to tell the client + # there is no need for further queries. + if not events: + return { + "chunk": [], + "start": await from_token.to_string(self.store), + } - events = await filter_events_for_client( - self.storage, user_id, events, is_peeking=(member_event_id is None) - ) + if event_filter: + events = await event_filter.filter(events) + + events = await filter_events_for_client( + self._storage_controllers, + user_id, + events, + is_peeking=(member_event_id is None), + ) + # if after the filter applied there are no more events + # return immediately - but there might be more in next_token batch if not events: return { "chunk": [], @@ -539,7 +553,7 @@ class PaginationHandler: (EventTypes.Member, event.sender) for event in events ) - state_ids = await self.state_store.get_state_ids_for_event( + state_ids = await self._state_storage_controller.get_state_ids_for_event( events[0].event_id, state_filter=state_filter ) @@ -653,7 +667,7 @@ class PaginationHandler: 400, "Users are still joined to this room" ) - await self.storage.purge_events.purge_room(room_id) + await self._storage_controllers.purge_events.purge_room(room_id) logger.info("complete") self._delete_by_id[delete_id].status = DeleteStatus.STATUS_COMPLETE diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index dd84e6c88b..bf112b9e1e 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -49,7 +49,7 @@ from prometheus_client import Counter from typing_extensions import ContextManager import synapse.metrics -from synapse.api.constants import EventTypes, Membership, PresenceState +from synapse.api.constants import EduTypes, EventTypes, Membership, PresenceState from synapse.api.errors import SynapseError from synapse.api.presence import UserPresenceState from synapse.appservice import ApplicationService @@ -394,7 +394,7 @@ class WorkerPresenceHandler(BasePresenceHandler): # Route presence EDUs to the right worker hs.get_federation_registry().register_instances_for_edu( - "m.presence", + EduTypes.PRESENCE, hs.config.worker.writers.presence, ) @@ -649,7 +649,9 @@ class PresenceHandler(BasePresenceHandler): federation_registry = hs.get_federation_registry() - federation_registry.register_edu_handler("m.presence", self.incoming_presence) + federation_registry.register_edu_handler( + EduTypes.PRESENCE, self.incoming_presence + ) LaterGauge( "synapse_handlers_presence_user_to_current_state_size", diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index e6a35f1d09..43d2882b0a 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -14,7 +14,7 @@ import logging from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple -from synapse.api.constants import ReceiptTypes +from synapse.api.constants import EduTypes, ReceiptTypes from synapse.appservice import ApplicationService from synapse.streams import EventSource from synapse.types import ( @@ -52,11 +52,11 @@ class ReceiptsHandler: # to the appropriate worker. if hs.get_instance_name() in hs.config.worker.writers.receipts: hs.get_federation_registry().register_edu_handler( - "m.receipt", self._received_remote_receipt + EduTypes.RECEIPT, self._received_remote_receipt ) else: hs.get_federation_registry().register_instances_for_edu( - "m.receipt", + EduTypes.RECEIPT, hs.config.worker.writers.receipts, ) diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index ab7e54857d..9a1cc11bb3 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -69,7 +69,7 @@ class BundledAggregations: class RelationsHandler: def __init__(self, hs: "HomeServer"): self._main_store = hs.get_datastores().main - self._storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() self._auth = hs.get_auth() self._clock = hs.get_clock() self._event_handler = hs.get_event_handler() @@ -143,7 +143,10 @@ class RelationsHandler: ) events = await filter_events_for_client( - self._storage, user_id, events, is_peeking=(member_event_id is None) + self._storage_controllers, + user_id, + events, + is_peeking=(member_event_id is None), ) now = self._clock.time_msec() diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 92e1de0500..5c91d33f58 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1192,8 +1192,8 @@ class RoomContextHandler: self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastores().main - self.storage = hs.get_storage() - self.state_store = self.storage.state + self._storage_controllers = hs.get_storage_controllers() + self._state_storage_controller = self._storage_controllers.state self._relations_handler = hs.get_relations_handler() async def get_event_context( @@ -1236,7 +1236,10 @@ class RoomContextHandler: if use_admin_priviledge: return events return await filter_events_for_client( - self.storage, user.to_string(), events, is_peeking=is_peeking + self._storage_controllers, + user.to_string(), + events, + is_peeking=is_peeking, ) event = await self.store.get_event( @@ -1293,7 +1296,7 @@ class RoomContextHandler: # first? Shouldn't we be consistent with /sync? # https://github.com/matrix-org/matrix-doc/issues/687 - state = await self.state_store.get_state_for_events( + state = await self._state_storage_controller.get_state_for_events( [last_event_id], state_filter=state_filter ) diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py index fbfd748406..1414e575d6 100644 --- a/synapse/handlers/room_batch.py +++ b/synapse/handlers/room_batch.py @@ -17,7 +17,7 @@ class RoomBatchHandler: def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastores().main - self.state_store = hs.get_storage().state + self._state_storage_controller = hs.get_storage_controllers().state self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() @@ -141,7 +141,7 @@ class RoomBatchHandler: ) = await self.store.get_max_depth_of(event_ids) # mapping from (type, state_key) -> state_event_id assert most_recent_event_id is not None - prev_state_map = await self.state_store.get_state_ids_for_event( + prev_state_map = await self._state_storage_controller.get_state_ids_for_event( most_recent_event_id ) # List of state event ID's diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index ea876c168d..00662dc961 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -1081,17 +1081,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # Transfer alias mappings in the room directory await self.store.update_aliases_for_room(old_room_id, room_id) - # Check if any groups we own contain the predecessor room - local_group_ids = await self.store.get_local_groups_for_room(old_room_id) - for group_id in local_group_ids: - # Add new the new room to those groups - await self.store.add_room_to_group( - group_id, room_id, old_room is not None and old_room["is_public"] - ) - - # Remove the old room from those groups - await self.store.remove_room_from_group(group_id, old_room_id) - async def copy_user_state_on_room_upgrade( self, old_room_id: str, new_room_id: str, user_ids: Iterable[str] ) -> None: diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index af83de3193..75aee6a111 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -662,7 +662,8 @@ class RoomSummaryHandler: # The API doesn't return the room version so assume that a # join rule of knock is valid. if ( - room.get("join_rules") in (JoinRules.PUBLIC, JoinRules.KNOCK) + room.get("join_rule") + in (JoinRules.PUBLIC, JoinRules.KNOCK, JoinRules.KNOCK_RESTRICTED) or room.get("world_readable") is True ): return True @@ -713,9 +714,6 @@ class RoomSummaryHandler: "canonical_alias": stats["canonical_alias"], "num_joined_members": stats["joined_members"], "avatar_url": stats["avatar"], - # plural join_rules is a documentation error but kept for historical - # purposes. Should match /publicRooms. - "join_rules": stats["join_rules"], "join_rule": stats["join_rules"], "world_readable": ( stats["history_visibility"] == HistoryVisibility.WORLD_READABLE diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index cd1c47dae8..659f99f7e2 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -55,8 +55,8 @@ class SearchHandler: self.hs = hs self._event_serializer = hs.get_event_client_serializer() self._relations_handler = hs.get_relations_handler() - self.storage = hs.get_storage() - self.state_store = self.storage.state + self._storage_controllers = hs.get_storage_controllers() + self._state_storage_controller = self._storage_controllers.state self.auth = hs.get_auth() async def get_old_rooms_from_upgraded_room(self, room_id: str) -> Iterable[str]: @@ -460,7 +460,7 @@ class SearchHandler: filtered_events = await search_filter.filter([r["event"] for r in results]) events = await filter_events_for_client( - self.storage, user.to_string(), filtered_events + self._storage_controllers, user.to_string(), filtered_events ) events.sort(key=lambda e: -rank_map[e.event_id]) @@ -559,7 +559,7 @@ class SearchHandler: filtered_events = await search_filter.filter([r["event"] for r in results]) events = await filter_events_for_client( - self.storage, user.to_string(), filtered_events + self._storage_controllers, user.to_string(), filtered_events ) room_events.extend(events) @@ -644,11 +644,11 @@ class SearchHandler: ) events_before = await filter_events_for_client( - self.storage, user.to_string(), res.events_before + self._storage_controllers, user.to_string(), res.events_before ) events_after = await filter_events_for_client( - self.storage, user.to_string(), res.events_after + self._storage_controllers, user.to_string(), res.events_after ) context: JsonDict = { @@ -677,7 +677,7 @@ class SearchHandler: [(EventTypes.Member, sender) for sender in senders] ) - state = await self.state_store.get_state_for_event( + state = await self._state_storage_controller.get_state_for_event( last_event_id, state_filter ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 59b5d497be..b5859dcb28 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -166,16 +166,6 @@ class KnockedSyncResult: return True -@attr.s(slots=True, frozen=True, auto_attribs=True) -class GroupsSyncResult: - join: JsonDict - invite: JsonDict - leave: JsonDict - - def __bool__(self) -> bool: - return bool(self.join or self.invite or self.leave) - - @attr.s(slots=True, auto_attribs=True) class _RoomChanges: """The set of room entries to include in the sync, plus the set of joined @@ -206,7 +196,6 @@ class SyncResult: for this device device_unused_fallback_key_types: List of key types that have an unused fallback key - groups: Group updates, if any """ next_batch: StreamToken @@ -220,7 +209,6 @@ class SyncResult: device_lists: DeviceListUpdates device_one_time_keys_count: JsonDict device_unused_fallback_key_types: List[str] - groups: Optional[GroupsSyncResult] def __bool__(self) -> bool: """Make the result appear empty if there are no updates. This is used @@ -236,7 +224,6 @@ class SyncResult: or self.account_data or self.to_device or self.device_lists - or self.groups ) @@ -251,8 +238,8 @@ class SyncHandler: self.clock = hs.get_clock() self.state = hs.get_state_handler() self.auth = hs.get_auth() - self.storage = hs.get_storage() - self.state_store = self.storage.state + self._storage_controllers = hs.get_storage_controllers() + self._state_storage_controller = self._storage_controllers.state # TODO: flush cache entries on subsequent sync request. # Once we get the next /sync request (ie, one with the same access token @@ -525,7 +512,7 @@ class SyncHandler: current_state_ids = frozenset(current_state_ids_map.values()) recents = await filter_events_for_client( - self.storage, + self._storage_controllers, sync_config.user.to_string(), recents, always_include_ids=current_state_ids, @@ -593,7 +580,7 @@ class SyncHandler: current_state_ids = frozenset(current_state_ids_map.values()) loaded_recents = await filter_events_for_client( - self.storage, + self._storage_controllers, sync_config.user.to_string(), loaded_recents, always_include_ids=current_state_ids, @@ -643,7 +630,7 @@ class SyncHandler: event: event of interest state_filter: The state filter used to fetch state from the database. """ - state_ids = await self.state_store.get_state_ids_for_event( + state_ids = await self._state_storage_controller.get_state_ids_for_event( event.event_id, state_filter=state_filter or StateFilter.all() ) if event.is_state(): @@ -723,7 +710,7 @@ class SyncHandler: return None last_event = last_events[-1] - state_ids = await self.state_store.get_state_ids_for_event( + state_ids = await self._state_storage_controller.get_state_ids_for_event( last_event.event_id, state_filter=StateFilter.from_types( [(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")] @@ -901,12 +888,16 @@ class SyncHandler: if full_state: if batch: - current_state_ids = await self.state_store.get_state_ids_for_event( - batch.events[-1].event_id, state_filter=state_filter + current_state_ids = ( + await self._state_storage_controller.get_state_ids_for_event( + batch.events[-1].event_id, state_filter=state_filter + ) ) - state_ids = await self.state_store.get_state_ids_for_event( - batch.events[0].event_id, state_filter=state_filter + state_ids = ( + await self._state_storage_controller.get_state_ids_for_event( + batch.events[0].event_id, state_filter=state_filter + ) ) else: @@ -926,7 +917,7 @@ class SyncHandler: elif batch.limited: if batch: state_at_timeline_start = ( - await self.state_store.get_state_ids_for_event( + await self._state_storage_controller.get_state_ids_for_event( batch.events[0].event_id, state_filter=state_filter ) ) @@ -960,8 +951,10 @@ class SyncHandler: ) if batch: - current_state_ids = await self.state_store.get_state_ids_for_event( - batch.events[-1].event_id, state_filter=state_filter + current_state_ids = ( + await self._state_storage_controller.get_state_ids_for_event( + batch.events[-1].event_id, state_filter=state_filter + ) ) else: # Its not clear how we get here, but empirically we do @@ -991,7 +984,7 @@ class SyncHandler: # So we fish out all the member events corresponding to the # timeline here, and then dedupe any redundant ones below. - state_ids = await self.state_store.get_state_ids_for_event( + state_ids = await self._state_storage_controller.get_state_ids_for_event( batch.events[0].event_id, # we only want members! state_filter=StateFilter.from_types( @@ -1157,10 +1150,6 @@ class SyncHandler: await self.store.get_e2e_unused_fallback_key_types(user_id, device_id) ) - if self.hs_config.experimental.groups_enabled: - logger.debug("Fetching group data") - await self._generate_sync_entry_for_groups(sync_result_builder) - num_events = 0 # debug for https://github.com/matrix-org/synapse/issues/9424 @@ -1184,57 +1173,11 @@ class SyncHandler: archived=sync_result_builder.archived, to_device=sync_result_builder.to_device, device_lists=device_lists, - groups=sync_result_builder.groups, device_one_time_keys_count=one_time_key_counts, device_unused_fallback_key_types=unused_fallback_key_types, next_batch=sync_result_builder.now_token, ) - @measure_func("_generate_sync_entry_for_groups") - async def _generate_sync_entry_for_groups( - self, sync_result_builder: "SyncResultBuilder" - ) -> None: - user_id = sync_result_builder.sync_config.user.to_string() - since_token = sync_result_builder.since_token - now_token = sync_result_builder.now_token - - if since_token and since_token.groups_key: - results = await self.store.get_groups_changes_for_user( - user_id, since_token.groups_key, now_token.groups_key - ) - else: - results = await self.store.get_all_groups_for_user( - user_id, now_token.groups_key - ) - - invited = {} - joined = {} - left = {} - for result in results: - membership = result["membership"] - group_id = result["group_id"] - gtype = result["type"] - content = result["content"] - - if membership == "join": - if gtype == "membership": - # TODO: Add profile - content.pop("membership", None) - joined[group_id] = content["content"] - else: - joined.setdefault(group_id, {})[gtype] = content - elif membership == "invite": - if gtype == "membership": - content.pop("membership", None) - invited[group_id] = content["content"] - else: - if gtype == "membership": - left[group_id] = content["content"] - - sync_result_builder.groups = GroupsSyncResult( - join=joined, invite=invited, leave=left - ) - @measure_func("_generate_sync_entry_for_device_list") async def _generate_sync_entry_for_device_list( self, @@ -2333,7 +2276,6 @@ class SyncResultBuilder: invited knocked archived - groups to_device """ @@ -2349,7 +2291,6 @@ class SyncResultBuilder: invited: List[InvitedSyncResult] = attr.Factory(list) knocked: List[KnockedSyncResult] = attr.Factory(list) archived: List[ArchivedSyncResult] = attr.Factory(list) - groups: Optional[GroupsSyncResult] = None to_device: List[JsonDict] = attr.Factory(list) def calculate_user_changes(self) -> Tuple[Set[str], Set[str]]: diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index bb00750bfd..0aeab86bbb 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple import attr +from synapse.api.constants import EduTypes from synapse.api.errors import AuthError, ShadowBanError, SynapseError from synapse.appservice import ApplicationService from synapse.metrics.background_process_metrics import ( @@ -68,7 +69,7 @@ class FollowerTypingHandler: if hs.get_instance_name() not in hs.config.worker.writers.typing: hs.get_federation_registry().register_instances_for_edu( - "m.typing", + EduTypes.TYPING, hs.config.worker.writers.typing, ) @@ -143,7 +144,7 @@ class FollowerTypingHandler: logger.debug("sending typing update to %s", domain) self.federation.build_and_send_edu( destination=domain, - edu_type="m.typing", + edu_type=EduTypes.TYPING, content={ "room_id": member.room_id, "user_id": member.user_id, @@ -218,7 +219,9 @@ class TypingWriterHandler(FollowerTypingHandler): self.hs = hs - hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu) + hs.get_federation_registry().register_edu_handler( + EduTypes.TYPING, self._recv_edu + ) hs.get_distributor().observe("user_left_room", self.user_left_room) @@ -458,7 +461,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]): def _make_event_for(self, room_id: str) -> JsonDict: typing = self.get_typing_handler()._room_typing[room_id] return { - "type": "m.typing", + "type": EduTypes.TYPING, "room_id": room_id, "content": {"user_ids": list(typing)}, } diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 901c47f756..776ed43f03 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -92,9 +92,6 @@ incoming_responses_counter = Counter( "synapse_http_matrixfederationclient_responses", "", ["method", "code"] ) -# a federation response can be rather large (eg a big state_ids is 50M or so), so we -# need a generous limit here. -MAX_RESPONSE_SIZE = 100 * 1024 * 1024 MAX_LONG_RETRIES = 10 MAX_SHORT_RETRIES = 3 @@ -116,6 +113,11 @@ class ByteParser(ByteWriteable, Generic[T], abc.ABC): the content type doesn't match we fail the request. """ + # a federation response can be rather large (eg a big state_ids is 50M or so), so we + # need a generous limit here. + MAX_RESPONSE_SIZE: int = 100 * 1024 * 1024 + """The largest response this parser will accept.""" + @abc.abstractmethod def finish(self) -> T: """Called when response has finished streaming and the parser should @@ -203,7 +205,6 @@ async def _handle_response( response: IResponse, start_ms: int, parser: ByteParser[T], - max_response_size: Optional[int] = None, ) -> T: """ Reads the body of a response with a timeout and sends it to a parser @@ -215,15 +216,12 @@ async def _handle_response( response: response to the request start_ms: Timestamp when request was made parser: The parser for the response - max_response_size: The maximum size to read from the response, if None - uses the default. Returns: The parsed response """ - if max_response_size is None: - max_response_size = MAX_RESPONSE_SIZE + max_response_size = parser.MAX_RESPONSE_SIZE finished = False try: @@ -242,7 +240,7 @@ async def _handle_response( "{%s} [%s] JSON response exceeded max size %i - %s %s", request.txn_id, request.destination, - MAX_RESPONSE_SIZE, + max_response_size, request.method, request.uri.decode("ascii"), ) @@ -783,7 +781,6 @@ class MatrixFederationHttpClient: backoff_on_404: bool = False, try_trailing_slash_on_400: bool = False, parser: Literal[None] = None, - max_response_size: Optional[int] = None, ) -> Union[JsonDict, list]: ... @@ -801,7 +798,6 @@ class MatrixFederationHttpClient: backoff_on_404: bool = False, try_trailing_slash_on_400: bool = False, parser: Optional[ByteParser[T]] = None, - max_response_size: Optional[int] = None, ) -> T: ... @@ -818,7 +814,6 @@ class MatrixFederationHttpClient: backoff_on_404: bool = False, try_trailing_slash_on_400: bool = False, parser: Optional[ByteParser] = None, - max_response_size: Optional[int] = None, ): """Sends the specified json data using PUT @@ -854,8 +849,6 @@ class MatrixFederationHttpClient: enabled. parser: The parser to use to decode the response. Defaults to parsing as JSON. - max_response_size: The maximum size to read from the response, if None - uses the default. Returns: Succeeds when we get a 2xx HTTP response. The @@ -906,7 +899,6 @@ class MatrixFederationHttpClient: response, start_ms, parser=parser, - max_response_size=max_response_size, ) return body @@ -995,7 +987,6 @@ class MatrixFederationHttpClient: ignore_backoff: bool = False, try_trailing_slash_on_400: bool = False, parser: Literal[None] = None, - max_response_size: Optional[int] = None, ) -> Union[JsonDict, list]: ... @@ -1010,7 +1001,6 @@ class MatrixFederationHttpClient: ignore_backoff: bool = ..., try_trailing_slash_on_400: bool = ..., parser: ByteParser[T] = ..., - max_response_size: Optional[int] = ..., ) -> T: ... @@ -1024,7 +1014,6 @@ class MatrixFederationHttpClient: ignore_backoff: bool = False, try_trailing_slash_on_400: bool = False, parser: Optional[ByteParser] = None, - max_response_size: Optional[int] = None, ): """GETs some json from the given host homeserver and path @@ -1054,9 +1043,6 @@ class MatrixFederationHttpClient: parser: The parser to use to decode the response. Defaults to parsing as JSON. - max_response_size: The maximum size to read from the response. If None, - uses the default. - Returns: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. @@ -1101,7 +1087,6 @@ class MatrixFederationHttpClient: response, start_ms, parser=parser, - max_response_size=max_response_size, ) return body diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index a02b5bf6bd..903ec40c86 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -168,9 +168,24 @@ import inspect import logging import re from functools import wraps -from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Pattern, Type +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Collection, + Dict, + Generator, + Iterable, + List, + Optional, + Pattern, + Type, + TypeVar, + Union, +) import attr +from typing_extensions import ParamSpec from twisted.internet import defer from twisted.web.http import Request @@ -256,7 +271,7 @@ try: def set_process(self, *args, **kwargs): return self._reporter.set_process(*args, **kwargs) - def report_span(self, span): + def report_span(self, span: "opentracing.Span") -> None: try: return self._reporter.report_span(span) except Exception: @@ -307,15 +322,19 @@ _homeserver_whitelist: Optional[Pattern[str]] = None Sentinel = object() -def only_if_tracing(func): +P = ParamSpec("P") +R = TypeVar("R") + + +def only_if_tracing(func: Callable[P, R]) -> Callable[P, Optional[R]]: """Executes the function only if we're tracing. Otherwise returns None.""" @wraps(func) - def _only_if_tracing_inner(*args, **kwargs): + def _only_if_tracing_inner(*args: P.args, **kwargs: P.kwargs) -> Optional[R]: if opentracing: return func(*args, **kwargs) else: - return + return None return _only_if_tracing_inner @@ -356,17 +375,10 @@ def ensure_active_span(message, ret=None): return ensure_active_span_inner_1 -@contextlib.contextmanager -def noop_context_manager(*args, **kwargs): - """Does exactly what it says on the tin""" - # TODO: replace with contextlib.nullcontext once we drop support for Python 3.6 - yield - - # Setup -def init_tracer(hs: "HomeServer"): +def init_tracer(hs: "HomeServer") -> None: """Set the whitelists and initialise the JaegerClient tracer""" global opentracing if not hs.config.tracing.opentracer_enabled: @@ -408,11 +420,11 @@ def init_tracer(hs: "HomeServer"): @only_if_tracing -def set_homeserver_whitelist(homeserver_whitelist): +def set_homeserver_whitelist(homeserver_whitelist: Iterable[str]) -> None: """Sets the homeserver whitelist Args: - homeserver_whitelist (Iterable[str]): regex of whitelisted homeservers + homeserver_whitelist: regexes specifying whitelisted homeservers """ global _homeserver_whitelist if homeserver_whitelist: @@ -423,15 +435,15 @@ def set_homeserver_whitelist(homeserver_whitelist): @only_if_tracing -def whitelisted_homeserver(destination): +def whitelisted_homeserver(destination: str) -> bool: """Checks if a destination matches the whitelist Args: - destination (str) + destination """ if _homeserver_whitelist: - return _homeserver_whitelist.match(destination) + return _homeserver_whitelist.match(destination) is not None return False @@ -457,11 +469,11 @@ def start_active_span( Args: See opentracing.tracer Returns: - scope (Scope) or noop_context_manager + scope (Scope) or contextlib.nullcontext """ if opentracing is None: - return noop_context_manager() # type: ignore[unreachable] + return contextlib.nullcontext() # type: ignore[unreachable] if tracer is None: # use the global tracer by default @@ -505,7 +517,7 @@ def start_active_span_follows_from( tracer: override the opentracing tracer. By default the global tracer is used. """ if opentracing is None: - return noop_context_manager() # type: ignore[unreachable] + return contextlib.nullcontext() # type: ignore[unreachable] references = [opentracing.follows_from(context) for context in contexts] scope = start_active_span( @@ -525,19 +537,19 @@ def start_active_span_follows_from( def start_active_span_from_edu( - edu_content, - operation_name, - references: Optional[list] = None, - tags=None, - start_time=None, - ignore_active_span=False, - finish_on_close=True, -): + edu_content: Dict[str, Any], + operation_name: str, + references: Optional[List["opentracing.Reference"]] = None, + tags: Optional[Dict] = None, + start_time: Optional[float] = None, + ignore_active_span: bool = False, + finish_on_close: bool = True, +) -> "opentracing.Scope": """ Extracts a span context from an edu and uses it to start a new active span Args: - edu_content (dict): and edu_content with a `context` field whose value is + edu_content: an edu_content with a `context` field whose value is canonical json for a dict which contains opentracing information. For the other args see opentracing.tracer @@ -545,7 +557,7 @@ def start_active_span_from_edu( references = references or [] if opentracing is None: - return noop_context_manager() # type: ignore[unreachable] + return contextlib.nullcontext() # type: ignore[unreachable] carrier = json_decoder.decode(edu_content.get("context", "{}")).get( "opentracing", {} @@ -578,27 +590,27 @@ def start_active_span_from_edu( # Opentracing setters for tags, logs, etc @only_if_tracing -def active_span(): +def active_span() -> Optional["opentracing.Span"]: """Get the currently active span, if any""" return opentracing.tracer.active_span @ensure_active_span("set a tag") -def set_tag(key, value): +def set_tag(key: str, value: Union[str, bool, int, float]) -> None: """Sets a tag on the active span""" assert opentracing.tracer.active_span is not None opentracing.tracer.active_span.set_tag(key, value) @ensure_active_span("log") -def log_kv(key_values, timestamp=None): +def log_kv(key_values: Dict[str, Any], timestamp: Optional[float] = None) -> None: """Log to the active span""" assert opentracing.tracer.active_span is not None opentracing.tracer.active_span.log_kv(key_values, timestamp) @ensure_active_span("set the traces operation name") -def set_operation_name(operation_name): +def set_operation_name(operation_name: str) -> None: """Sets the operation name of the active span""" assert opentracing.tracer.active_span is not None opentracing.tracer.active_span.set_operation_name(operation_name) @@ -624,7 +636,9 @@ def force_tracing(span=Sentinel) -> None: span.set_baggage_item(SynapseBaggage.FORCE_TRACING, "1") -def is_context_forced_tracing(span_context) -> bool: +def is_context_forced_tracing( + span_context: Optional["opentracing.SpanContext"], +) -> bool: """Check if sampling has been force for the given span context.""" if span_context is None: return False @@ -696,13 +710,13 @@ def inject_response_headers(response_headers: Headers) -> None: @ensure_active_span("get the active span context as a dict", ret={}) -def get_active_span_text_map(destination=None): +def get_active_span_text_map(destination: Optional[str] = None) -> Dict[str, str]: """ Gets a span context as a dict. This can be used instead of manually injecting a span into an empty carrier. Args: - destination (str): the name of the remote server. + destination: the name of the remote server. Returns: dict: the active span's context if opentracing is enabled, otherwise empty. @@ -721,7 +735,7 @@ def get_active_span_text_map(destination=None): @ensure_active_span("get the span context as a string.", ret={}) -def active_span_context_as_string(): +def active_span_context_as_string() -> str: """ Returns: The active span context encoded as a string. @@ -750,21 +764,21 @@ def span_context_from_request(request: Request) -> "Optional[opentracing.SpanCon @only_if_tracing -def span_context_from_string(carrier): +def span_context_from_string(carrier: str) -> Optional["opentracing.SpanContext"]: """ Returns: The active span context decoded from a string. """ - carrier = json_decoder.decode(carrier) - return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier) + payload: Dict[str, str] = json_decoder.decode(carrier) + return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, payload) @only_if_tracing -def extract_text_map(carrier): +def extract_text_map(carrier: Dict[str, str]) -> Optional["opentracing.SpanContext"]: """ Wrapper method for opentracing's tracer.extract for TEXT_MAP. Args: - carrier (dict): a dict possibly containing a span context. + carrier: a dict possibly containing a span context. Returns: The active span context extracted from carrier. @@ -843,7 +857,7 @@ def trace(func=None, opname=None): return decorator -def tag_args(func): +def tag_args(func: Callable[P, R]) -> Callable[P, R]: """ Tags all of the args to the active span. """ @@ -852,11 +866,11 @@ def tag_args(func): return func @wraps(func) - def _tag_args_inner(*args, **kwargs): + def _tag_args_inner(*args: P.args, **kwargs: P.kwargs) -> R: argspec = inspect.getfullargspec(func) for i, arg in enumerate(argspec.args[1:]): - set_tag("ARG_" + arg, args[i]) - set_tag("args", args[len(argspec.args) :]) + set_tag("ARG_" + arg, args[i]) # type: ignore[index] + set_tag("args", args[len(argspec.args) :]) # type: ignore[index] set_tag("kwargs", kwargs) return func(*args, **kwargs) @@ -864,7 +878,9 @@ def tag_args(func): @contextlib.contextmanager -def trace_servlet(request: "SynapseRequest", extract_context: bool = False): +def trace_servlet( + request: "SynapseRequest", extract_context: bool = False +) -> Generator[None, None, None]: """Returns a context manager which traces a request. It starts a span with some servlet specific tags such as the request metrics name and request information. diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index 298809742a..eef3462e10 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -14,6 +14,7 @@ import logging import threading +from contextlib import nullcontext from functools import wraps from types import TracebackType from typing import ( @@ -41,11 +42,7 @@ from synapse.logging.context import ( LoggingContext, PreserveLoggingContext, ) -from synapse.logging.opentracing import ( - SynapseTags, - noop_context_manager, - start_active_span, -) +from synapse.logging.opentracing import SynapseTags, start_active_span from synapse.metrics._types import Collector if TYPE_CHECKING: @@ -238,7 +235,7 @@ def run_as_background_process( f"bgproc.{desc}", tags={SynapseTags.REQUEST_ID: str(context)} ) else: - ctx = noop_context_manager() + ctx = nullcontext() with ctx: return await func(*args, **kwargs) except Exception: diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index c44e9da121..b7451fc870 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -1147,7 +1147,10 @@ class ModuleApi: ) async def sleep(self, seconds: float) -> None: - """Sleeps for the given number of seconds.""" + """Sleeps for the given number of seconds. + + Added in Synapse v1.49.0. + """ await self._clock.sleep(seconds) @@ -1427,6 +1430,28 @@ class ModuleApi: user_id, spec, {"actions": actions} ) + async def get_monthly_active_users_by_service( + self, start_timestamp: Optional[int] = None, end_timestamp: Optional[int] = None + ) -> List[Tuple[str, str]]: + """Generates list of monthly active users and their services. + Please see corresponding storage docstring for more details. + + Added in Synapse v1.61.0. + + Arguments: + start_timestamp: If specified, only include users that were first active + at or after this point + end_timestamp: If specified, only include users that were first active + at or before this point + + Returns: + A list of tuples (appservice_id, user_id) + + """ + return await self._store.get_monthly_active_users_by_service( + start_timestamp, end_timestamp + ) + class PublicRoomListManager: """Contains methods for adding to, removing from and querying whether a room diff --git a/synapse/notifier.py b/synapse/notifier.py index ba23257f54..1100434b3f 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -33,7 +33,7 @@ from prometheus_client import Counter from twisted.internet import defer -from synapse.api.constants import EventTypes, HistoryVisibility, Membership +from synapse.api.constants import EduTypes, EventTypes, HistoryVisibility, Membership from synapse.api.errors import AuthError from synapse.events import EventBase from synapse.handlers.presence import format_user_presence_state @@ -221,7 +221,7 @@ class Notifier: self.room_to_user_streams: Dict[str, Set[_NotifierUserStream]] = {} self.hs = hs - self.storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() self.event_sources = hs.get_event_sources() self.store = hs.get_datastores().main self.pending_new_room_events: List[_PendingRoomEventEntry] = [] @@ -623,7 +623,7 @@ class Notifier: if name == "room": new_events = await filter_events_for_client( - self.storage, + self._storage_controllers, user.to_string(), new_events, is_peeking=is_peeking, @@ -632,7 +632,7 @@ class Notifier: now = self.clock.time_msec() new_events[:] = [ { - "type": "m.presence", + "type": EduTypes.PRESENCE, "content": format_user_presence_state(event, now), } for event in new_events diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index a17b35a605..819bc9e9b6 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -139,6 +139,7 @@ BASE_APPEND_CONTENT_RULES: List[Dict[str, Any]] = [ { "kind": "event_match", "key": "content.body", + # Match the localpart of the requester's MXID. "pattern_type": "user_localpart", } ], @@ -191,6 +192,7 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [ "pattern": "invite", "_cache_key": "_invite_member", }, + # Match the requester's MXID. {"kind": "event_match", "key": "state_key", "pattern_type": "user_id"}, ], "actions": [ @@ -290,7 +292,7 @@ BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [ "_cache_key": "_room_server_acl", } ], - "actions": ["dont_notify"], + "actions": [], }, ] @@ -351,6 +353,18 @@ BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [ ], }, { + "rule_id": "global/underride/.org.matrix.msc3772.thread_reply", + "conditions": [ + { + "kind": "org.matrix.msc3772.relation_match", + "rel_type": "m.thread", + # Match the requester's MXID. + "sender_type": "user_id", + } + ], + "actions": ["notify", {"set_tweak": "highlight", "value": False}], + }, + { "rule_id": "global/underride/.m.rule.message", "conditions": [ { diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 4cc8a2ecca..7791b289e2 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -13,8 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import itertools import logging -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple, Union import attr from prometheus_client import Counter @@ -121,6 +122,9 @@ class BulkPushRuleEvaluator: resizable=False, ) + # Whether to support MSC3772 is supported. + self._relations_match_enabled = self.hs.config.experimental.msc3772_enabled + async def _get_rules_for_event( self, event: EventBase, context: EventContext ) -> Dict[str, List[Dict[str, Any]]]: @@ -149,12 +153,10 @@ class BulkPushRuleEvaluator: if event.type == "m.room.member" and event.content["membership"] == "invite": invited = event.state_key if invited and self.hs.is_mine_id(invited): - has_pusher = await self.store.user_has_pusher(invited) - if has_pusher: - rules_by_user = dict(rules_by_user) - rules_by_user[invited] = await self.store.get_push_rules_for_user( - invited - ) + rules_by_user = dict(rules_by_user) + rules_by_user[invited] = await self.store.get_push_rules_for_user( + invited + ) return rules_by_user @@ -192,6 +194,60 @@ class BulkPushRuleEvaluator: return pl_event.content if pl_event else {}, sender_level + async def _get_mutual_relations( + self, event: EventBase, rules: Iterable[Dict[str, Any]] + ) -> Dict[str, Set[Tuple[str, str]]]: + """ + Fetch event metadata for events which related to the same event as the given event. + + If the given event has no relation information, returns an empty dictionary. + + Args: + event_id: The event ID which is targeted by relations. + rules: The push rules which will be processed for this event. + + Returns: + A dictionary of relation type to: + A set of tuples of: + The sender + The event type + """ + + # If the experimental feature is not enabled, skip fetching relations. + if not self._relations_match_enabled: + return {} + + # If the event does not have a relation, then cannot have any mutual + # relations. + relation = relation_from_event(event) + if not relation: + return {} + + # Pre-filter to figure out which relation types are interesting. + rel_types = set() + for rule in rules: + # Skip disabled rules. + if "enabled" in rule and not rule["enabled"]: + continue + + for condition in rule["conditions"]: + if condition["kind"] != "org.matrix.msc3772.relation_match": + continue + + # rel_type is required. + rel_type = condition.get("rel_type") + if rel_type: + rel_types.add(rel_type) + + # If no valid rules were found, no mutual relations. + if not rel_types: + return {} + + # If any valid rules were found, fetch the mutual relations. + return await self.store.get_mutual_event_relations( + relation.parent_id, rel_types + ) + @measure_func("action_for_event_by_user") async def action_for_event_by_user( self, event: EventBase, context: EventContext @@ -216,8 +272,17 @@ class BulkPushRuleEvaluator: sender_power_level, ) = await self._get_power_levels_and_sender_level(event, context) + relations = await self._get_mutual_relations( + event, itertools.chain(*rules_by_user.values()) + ) + evaluator = PushRuleEvaluatorForEvent( - event, len(room_members), sender_power_level, power_levels + event, + len(room_members), + sender_power_level, + power_levels, + relations, + self._relations_match_enabled, ) # If the event is not a state event check if any users ignore the sender. diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py index 63b22d50ae..5117ef6854 100644 --- a/synapse/push/clientformat.py +++ b/synapse/push/clientformat.py @@ -48,6 +48,10 @@ def format_push_rules_for_user( elif pattern_type == "user_localpart": c["pattern"] = user.localpart + sender_type = c.pop("sender_type", None) + if sender_type == "user_id": + c["sender"] = user.to_string() + rulearray = rules["global"][template_name] template_rule = _rule_to_template(r) diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index d5603596c0..e96fb45e9f 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -65,7 +65,7 @@ class HttpPusher(Pusher): def __init__(self, hs: "HomeServer", pusher_config: PusherConfig): super().__init__(hs, pusher_config) - self.storage = self.hs.get_storage() + self._storage_controllers = self.hs.get_storage_controllers() self.app_display_name = pusher_config.app_display_name self.device_display_name = pusher_config.device_display_name self.pushkey_ts = pusher_config.ts @@ -343,7 +343,9 @@ class HttpPusher(Pusher): } return d - ctx = await push_tools.get_context_for_event(self.storage, event, self.user_id) + ctx = await push_tools.get_context_for_event( + self._storage_controllers, event, self.user_id + ) d = { "notification": { diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 5ccdd88364..63aefd07f5 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -114,10 +114,10 @@ class Mailer: self.send_email_handler = hs.get_send_email_handler() self.store = self.hs.get_datastores().main - self.state_store = self.hs.get_storage().state + self._state_storage_controller = self.hs.get_storage_controllers().state self.macaroon_gen = self.hs.get_macaroon_generator() self.state_handler = self.hs.get_state_handler() - self.storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() self.app_name = app_name self.email_subjects: EmailSubjectConfig = hs.config.email.email_subjects @@ -456,7 +456,7 @@ class Mailer: } the_events = await filter_events_for_client( - self.storage, user_id, results.events_before + self._storage_controllers, user_id, results.events_before ) the_events.append(notif_event) @@ -494,7 +494,7 @@ class Mailer: ) else: # Attempt to check the historical state for the room. - historical_state = await self.state_store.get_state_for_event( + historical_state = await self._state_storage_controller.get_state_for_event( event.event_id, StateFilter.from_types((type_state_key,)) ) sender_state_event = historical_state.get(type_state_key) @@ -767,8 +767,10 @@ class Mailer: member_event_ids.append(sender_state_event_id) else: # Attempt to check the historical state for the room. - historical_state = await self.state_store.get_state_for_event( - event_id, StateFilter.from_types((type_state_key,)) + historical_state = ( + await self._state_storage_controller.get_state_for_event( + event_id, StateFilter.from_types((type_state_key,)) + ) ) sender_state_event = historical_state.get(type_state_key) if sender_state_event: diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 54db6b5612..2e8a017add 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -15,7 +15,7 @@ import logging import re -from typing import Any, Dict, List, Mapping, Optional, Pattern, Tuple, Union +from typing import Any, Dict, List, Mapping, Optional, Pattern, Set, Tuple, Union from matrix_common.regex import glob_to_regex, to_word_pattern @@ -120,11 +120,15 @@ class PushRuleEvaluatorForEvent: room_member_count: int, sender_power_level: int, power_levels: Dict[str, Union[int, Dict[str, int]]], + relations: Dict[str, Set[Tuple[str, str]]], + relations_match_enabled: bool, ): self._event = event self._room_member_count = room_member_count self._sender_power_level = sender_power_level self._power_levels = power_levels + self._relations = relations + self._relations_match_enabled = relations_match_enabled # Maps strings of e.g. 'content.body' -> event["content"]["body"] self._value_cache = _flatten_dict(event) @@ -188,7 +192,16 @@ class PushRuleEvaluatorForEvent: return _sender_notification_permission( self._event, condition, self._sender_power_level, self._power_levels ) + elif ( + condition["kind"] == "org.matrix.msc3772.relation_match" + and self._relations_match_enabled + ): + return self._relation_match(condition, user_id) else: + # XXX This looks incorrect -- we have reached an unknown condition + # kind and are unconditionally returning that it matches. Note + # that it seems possible to provide a condition to the /pushrules + # endpoint with an unknown kind, see _rule_tuple_from_request_object. return True def _event_match(self, condition: dict, user_id: str) -> bool: @@ -256,6 +269,41 @@ class PushRuleEvaluatorForEvent: return bool(r.search(body)) + def _relation_match(self, condition: dict, user_id: str) -> bool: + """ + Check an "relation_match" push rule condition. + + Args: + condition: The "event_match" push rule condition to match. + user_id: The user's MXID. + + Returns: + True if the condition matches the event, False otherwise. + """ + rel_type = condition.get("rel_type") + if not rel_type: + logger.warning("relation_match condition missing rel_type") + return False + + sender_pattern = condition.get("sender") + if sender_pattern is None: + sender_type = condition.get("sender_type") + if sender_type == "user_id": + sender_pattern = user_id + type_pattern = condition.get("type") + + # If any other relations matches, return True. + for sender, event_type in self._relations.get(rel_type, ()): + if sender_pattern and not _glob_matches(sender_pattern, sender): + continue + if type_pattern and not _glob_matches(type_pattern, event_type): + continue + # All values must have matched. + return True + + # No relations matched. + return False + # Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches regex_cache: LruCache[Tuple[str, bool, bool], Pattern] = LruCache( diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index a1bf5b20dd..8397229ccb 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -16,7 +16,7 @@ from typing import Dict from synapse.api.constants import ReceiptTypes from synapse.events import EventBase from synapse.push.presentable_names import calculate_room_name, name_from_member_event -from synapse.storage import Storage +from synapse.storage.controllers import StorageControllers from synapse.storage.databases.main import DataStore @@ -52,7 +52,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) - async def get_context_for_event( - storage: Storage, ev: EventBase, user_id: str + storage: StorageControllers, ev: EventBase, user_id: str ) -> Dict[str, str]: ctx = {} diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 3e7300b4a1..eed29cd597 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -69,7 +69,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): super().__init__(hs) self.store = hs.get_datastores().main - self.storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() self.clock = hs.get_clock() self.federation_event_handler = hs.get_federation_event_handler() @@ -133,7 +133,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): event.internal_metadata.outlier = event_payload["outlier"] context = EventContext.deserialize( - self.storage, event_payload["context"] + self._storage_controllers, event_payload["context"] ) event_and_contexts.append((event, context)) diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index ce78176836..c2b2588ea5 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -70,7 +70,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): self.event_creation_handler = hs.get_event_creation_handler() self.store = hs.get_datastores().main - self.storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() self.clock = hs.get_clock() @staticmethod @@ -127,7 +127,9 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): event.internal_metadata.outlier = content["outlier"] requester = Requester.deserialize(self.store, content["requester"]) - context = EventContext.deserialize(self.storage, content["context"]) + context = EventContext.deserialize( + self._storage_controllers, content["context"] + ) ratelimit = content["ratelimit"] extra_users = [UserID.from_string(u) for u in content["extra_users"]] diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 57c4773edc..b712215112 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -26,7 +26,6 @@ from synapse.rest.client import ( directory, events, filter, - groups, initial_sync, keys, knock, @@ -118,8 +117,6 @@ class ClientRestResource(JsonResource): thirdparty.register_servlets(hs, client_resource) sendtodevice.register_servlets(hs, client_resource) user_directory.register_servlets(hs, client_resource) - if hs.config.experimental.groups_enabled: - groups.register_servlets(hs, client_resource) room_upgrade_rest_servlet.register_servlets(hs, client_resource) room_batch.register_servlets(hs, client_resource) capabilities.register_servlets(hs, client_resource) diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index cb4d55c89d..1aa08f8d95 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -47,7 +47,6 @@ from synapse.rest.admin.federation import ( DestinationRestServlet, ListDestinationsRestServlet, ) -from synapse.rest.admin.groups import DeleteGroupAdminRestServlet from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo from synapse.rest.admin.registration_tokens import ( ListRegistrationTokensRestServlet, @@ -293,8 +292,6 @@ def register_servlets_for_client_rest_resource( ResetPasswordRestServlet(hs).register(http_server) SearchUsersRestServlet(hs).register(http_server) UserRegisterServlet(hs).register(http_server) - if hs.config.experimental.groups_enabled: - DeleteGroupAdminRestServlet(hs).register(http_server) AccountValidityRenewServlet(hs).register(http_server) # Load the media repo ones if we're using them. Otherwise load the servlets which diff --git a/synapse/rest/admin/groups.py b/synapse/rest/admin/groups.py deleted file mode 100644 index cd697e180e..0000000000 --- a/synapse/rest/admin/groups.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -from http import HTTPStatus -from typing import TYPE_CHECKING, Tuple - -from synapse.api.errors import SynapseError -from synapse.http.servlet import RestServlet -from synapse.http.site import SynapseRequest -from synapse.rest.admin._base import admin_patterns, assert_user_is_admin -from synapse.types import JsonDict - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -class DeleteGroupAdminRestServlet(RestServlet): - """Allows deleting of local groups""" - - PATTERNS = admin_patterns("/delete_group/(?P<group_id>[^/]*)$") - - def __init__(self, hs: "HomeServer"): - self.group_server = hs.get_groups_server_handler() - self.is_mine_id = hs.is_mine_id - self.auth = hs.get_auth() - - async def on_POST( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - await assert_user_is_admin(self.auth, requester.user) - - if not self.is_mine_id(group_id): - raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local groups") - - await self.group_server.delete_group(group_id, requester.user.to_string()) - return HTTPStatus.OK, {} diff --git a/synapse/rest/client/groups.py b/synapse/rest/client/groups.py deleted file mode 100644 index 7e1149c7f4..0000000000 --- a/synapse/rest/client/groups.py +++ /dev/null @@ -1,962 +0,0 @@ -# Copyright 2017 Vector Creations Ltd -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from functools import wraps -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple - -from twisted.web.server import Request - -from synapse.api.constants import ( - MAX_GROUP_CATEGORYID_LENGTH, - MAX_GROUP_ROLEID_LENGTH, - MAX_GROUPID_LENGTH, -) -from synapse.api.errors import Codes, SynapseError -from synapse.handlers.groups_local import GroupsLocalHandler -from synapse.http.server import HttpServer -from synapse.http.servlet import ( - RestServlet, - assert_params_in_dict, - parse_json_object_from_request, -) -from synapse.http.site import SynapseRequest -from synapse.types import GroupID, JsonDict - -from ._base import client_patterns - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -def _validate_group_id( - f: Callable[..., Awaitable[Tuple[int, JsonDict]]] -) -> Callable[..., Awaitable[Tuple[int, JsonDict]]]: - """Wrapper to validate the form of the group ID. - - Can be applied to any on_FOO methods that accepts a group ID as a URL parameter. - """ - - @wraps(f) - def wrapper( - self: RestServlet, request: Request, group_id: str, *args: Any, **kwargs: Any - ) -> Awaitable[Tuple[int, JsonDict]]: - if not GroupID.is_valid(group_id): - raise SynapseError(400, "%s is not a legal group ID" % (group_id,)) - - return f(self, request, group_id, *args, **kwargs) - - return wrapper - - -class GroupServlet(RestServlet): - """Get the group profile""" - - PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/profile$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_GET( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - requester_user_id = requester.user.to_string() - - group_description = await self.groups_handler.get_group_profile( - group_id, requester_user_id - ) - - return 200, group_description - - @_validate_group_id - async def on_POST( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - content = parse_json_object_from_request(request) - assert_params_in_dict( - content, ("name", "avatar_url", "short_description", "long_description") - ) - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot create group profiles." - await self.groups_handler.update_group_profile( - group_id, requester_user_id, content - ) - - return 200, {} - - -class GroupSummaryServlet(RestServlet): - """Get the full group summary""" - - PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/summary$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_GET( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - requester_user_id = requester.user.to_string() - - get_group_summary = await self.groups_handler.get_group_summary( - group_id, requester_user_id - ) - - return 200, get_group_summary - - -class GroupSummaryRoomsCatServlet(RestServlet): - """Update/delete a rooms entry in the summary. - - Matches both: - - /groups/:group/summary/rooms/:room_id - - /groups/:group/summary/categories/:category/rooms/:room_id - """ - - PATTERNS = client_patterns( - "/groups/(?P<group_id>[^/]*)/summary" - "(/categories/(?P<category_id>[^/]+))?" - "/rooms/(?P<room_id>[^/]*)$" - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_PUT( - self, - request: SynapseRequest, - group_id: str, - category_id: Optional[str], - room_id: str, - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - if category_id == "": - raise SynapseError(400, "category_id cannot be empty", Codes.INVALID_PARAM) - - if category_id and len(category_id) > MAX_GROUP_CATEGORYID_LENGTH: - raise SynapseError( - 400, - "category_id may not be longer than %s characters" - % (MAX_GROUP_CATEGORYID_LENGTH,), - Codes.INVALID_PARAM, - ) - - content = parse_json_object_from_request(request) - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot modify group summaries." - resp = await self.groups_handler.update_group_summary_room( - group_id, - requester_user_id, - room_id=room_id, - category_id=category_id, - content=content, - ) - - return 200, resp - - @_validate_group_id - async def on_DELETE( - self, request: SynapseRequest, group_id: str, category_id: str, room_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot modify group profiles." - resp = await self.groups_handler.delete_group_summary_room( - group_id, requester_user_id, room_id=room_id, category_id=category_id - ) - - return 200, resp - - -class GroupCategoryServlet(RestServlet): - """Get/add/update/delete a group category""" - - PATTERNS = client_patterns( - "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$" - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_GET( - self, request: SynapseRequest, group_id: str, category_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - requester_user_id = requester.user.to_string() - - category = await self.groups_handler.get_group_category( - group_id, requester_user_id, category_id=category_id - ) - - return 200, category - - @_validate_group_id - async def on_PUT( - self, request: SynapseRequest, group_id: str, category_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - if not category_id: - raise SynapseError(400, "category_id cannot be empty", Codes.INVALID_PARAM) - - if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH: - raise SynapseError( - 400, - "category_id may not be longer than %s characters" - % (MAX_GROUP_CATEGORYID_LENGTH,), - Codes.INVALID_PARAM, - ) - - content = parse_json_object_from_request(request) - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot modify group categories." - resp = await self.groups_handler.update_group_category( - group_id, requester_user_id, category_id=category_id, content=content - ) - - return 200, resp - - @_validate_group_id - async def on_DELETE( - self, request: SynapseRequest, group_id: str, category_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot modify group categories." - resp = await self.groups_handler.delete_group_category( - group_id, requester_user_id, category_id=category_id - ) - - return 200, resp - - -class GroupCategoriesServlet(RestServlet): - """Get all group categories""" - - PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/categories/$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_GET( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - requester_user_id = requester.user.to_string() - - category = await self.groups_handler.get_group_categories( - group_id, requester_user_id - ) - - return 200, category - - -class GroupRoleServlet(RestServlet): - """Get/add/update/delete a group role""" - - PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_GET( - self, request: SynapseRequest, group_id: str, role_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - requester_user_id = requester.user.to_string() - - category = await self.groups_handler.get_group_role( - group_id, requester_user_id, role_id=role_id - ) - - return 200, category - - @_validate_group_id - async def on_PUT( - self, request: SynapseRequest, group_id: str, role_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - if not role_id: - raise SynapseError(400, "role_id cannot be empty", Codes.INVALID_PARAM) - - if len(role_id) > MAX_GROUP_ROLEID_LENGTH: - raise SynapseError( - 400, - "role_id may not be longer than %s characters" - % (MAX_GROUP_ROLEID_LENGTH,), - Codes.INVALID_PARAM, - ) - - content = parse_json_object_from_request(request) - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot modify group roles." - resp = await self.groups_handler.update_group_role( - group_id, requester_user_id, role_id=role_id, content=content - ) - - return 200, resp - - @_validate_group_id - async def on_DELETE( - self, request: SynapseRequest, group_id: str, role_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot modify group roles." - resp = await self.groups_handler.delete_group_role( - group_id, requester_user_id, role_id=role_id - ) - - return 200, resp - - -class GroupRolesServlet(RestServlet): - """Get all group roles""" - - PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_GET( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - requester_user_id = requester.user.to_string() - - category = await self.groups_handler.get_group_roles( - group_id, requester_user_id - ) - - return 200, category - - -class GroupSummaryUsersRoleServlet(RestServlet): - """Update/delete a user's entry in the summary. - - Matches both: - - /groups/:group/summary/users/:room_id - - /groups/:group/summary/roles/:role/users/:user_id - """ - - PATTERNS = client_patterns( - "/groups/(?P<group_id>[^/]*)/summary" - "(/roles/(?P<role_id>[^/]+))?" - "/users/(?P<user_id>[^/]*)$" - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_PUT( - self, - request: SynapseRequest, - group_id: str, - role_id: Optional[str], - user_id: str, - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - if role_id == "": - raise SynapseError(400, "role_id cannot be empty", Codes.INVALID_PARAM) - - if role_id and len(role_id) > MAX_GROUP_ROLEID_LENGTH: - raise SynapseError( - 400, - "role_id may not be longer than %s characters" - % (MAX_GROUP_ROLEID_LENGTH,), - Codes.INVALID_PARAM, - ) - - content = parse_json_object_from_request(request) - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot modify group summaries." - resp = await self.groups_handler.update_group_summary_user( - group_id, - requester_user_id, - user_id=user_id, - role_id=role_id, - content=content, - ) - - return 200, resp - - @_validate_group_id - async def on_DELETE( - self, request: SynapseRequest, group_id: str, role_id: str, user_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot modify group summaries." - resp = await self.groups_handler.delete_group_summary_user( - group_id, requester_user_id, user_id=user_id, role_id=role_id - ) - - return 200, resp - - -class GroupRoomServlet(RestServlet): - """Get all rooms in a group""" - - PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/rooms$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_GET( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - requester_user_id = requester.user.to_string() - - result = await self.groups_handler.get_rooms_in_group( - group_id, requester_user_id - ) - - return 200, result - - -class GroupUsersServlet(RestServlet): - """Get all users in a group""" - - PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/users$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_GET( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - requester_user_id = requester.user.to_string() - - result = await self.groups_handler.get_users_in_group( - group_id, requester_user_id - ) - - return 200, result - - -class GroupInvitedUsersServlet(RestServlet): - """Get users invited to a group""" - - PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/invited_users$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_GET( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - result = await self.groups_handler.get_invited_users_in_group( - group_id, requester_user_id - ) - - return 200, result - - -class GroupSettingJoinPolicyServlet(RestServlet): - """Set group join policy""" - - PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_PUT( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - content = parse_json_object_from_request(request) - - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot modify group join policy." - result = await self.groups_handler.set_group_join_policy( - group_id, requester_user_id, content - ) - - return 200, result - - -class GroupCreateServlet(RestServlet): - """Create a group""" - - PATTERNS = client_patterns("/create_group$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - self.server_name = hs.hostname - - async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - # TODO: Create group on remote server - content = parse_json_object_from_request(request) - localpart = content.pop("localpart") - group_id = GroupID(localpart, self.server_name).to_string() - - if not localpart: - raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM) - - if len(group_id) > MAX_GROUPID_LENGTH: - raise SynapseError( - 400, - "Group ID may not be longer than %s characters" % (MAX_GROUPID_LENGTH,), - Codes.INVALID_PARAM, - ) - - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot create groups." - result = await self.groups_handler.create_group( - group_id, requester_user_id, content - ) - - return 200, result - - -class GroupAdminRoomsServlet(RestServlet): - """Add a room to the group""" - - PATTERNS = client_patterns( - "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$" - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_PUT( - self, request: SynapseRequest, group_id: str, room_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - content = parse_json_object_from_request(request) - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot modify rooms in a group." - result = await self.groups_handler.add_room_to_group( - group_id, requester_user_id, room_id, content - ) - - return 200, result - - @_validate_group_id - async def on_DELETE( - self, request: SynapseRequest, group_id: str, room_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot modify group categories." - result = await self.groups_handler.remove_room_from_group( - group_id, requester_user_id, room_id - ) - - return 200, result - - -class GroupAdminRoomsConfigServlet(RestServlet): - """Update the config of a room in a group""" - - PATTERNS = client_patterns( - "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)" - "/config/(?P<config_key>[^/]*)$" - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_PUT( - self, request: SynapseRequest, group_id: str, room_id: str, config_key: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - content = parse_json_object_from_request(request) - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot modify group categories." - result = await self.groups_handler.update_room_in_group( - group_id, requester_user_id, room_id, config_key, content - ) - - return 200, result - - -class GroupAdminUsersInviteServlet(RestServlet): - """Invite a user to the group""" - - PATTERNS = client_patterns( - "/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$" - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - self.store = hs.get_datastores().main - self.is_mine_id = hs.is_mine_id - - @_validate_group_id - async def on_PUT( - self, request: SynapseRequest, group_id: str, user_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - content = parse_json_object_from_request(request) - config = content.get("config", {}) - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot invite users to a group." - result = await self.groups_handler.invite( - group_id, user_id, requester_user_id, config - ) - - return 200, result - - -class GroupAdminUsersKickServlet(RestServlet): - """Kick a user from the group""" - - PATTERNS = client_patterns( - "/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$" - ) - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_PUT( - self, request: SynapseRequest, group_id: str, user_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - content = parse_json_object_from_request(request) - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot kick users from a group." - result = await self.groups_handler.remove_user_from_group( - group_id, user_id, requester_user_id, content - ) - - return 200, result - - -class GroupSelfLeaveServlet(RestServlet): - """Leave a joined group""" - - PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/leave$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_PUT( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - content = parse_json_object_from_request(request) - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot leave a group for a users." - result = await self.groups_handler.remove_user_from_group( - group_id, requester_user_id, requester_user_id, content - ) - - return 200, result - - -class GroupSelfJoinServlet(RestServlet): - """Attempt to join a group, or knock""" - - PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/join$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_PUT( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - content = parse_json_object_from_request(request) - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot join a user to a group." - result = await self.groups_handler.join_group( - group_id, requester_user_id, content - ) - - return 200, result - - -class GroupSelfAcceptInviteServlet(RestServlet): - """Accept a group invite""" - - PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/accept_invite$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - @_validate_group_id - async def on_PUT( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - content = parse_json_object_from_request(request) - assert isinstance( - self.groups_handler, GroupsLocalHandler - ), "Workers cannot accept an invite to a group." - result = await self.groups_handler.accept_invite( - group_id, requester_user_id, content - ) - - return 200, result - - -class GroupSelfUpdatePublicityServlet(RestServlet): - """Update whether we publicise a users membership of a group""" - - PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/update_publicity$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.store = hs.get_datastores().main - - @_validate_group_id - async def on_PUT( - self, request: SynapseRequest, group_id: str - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) - requester_user_id = requester.user.to_string() - - content = parse_json_object_from_request(request) - publicise = content["publicise"] - await self.store.update_group_publicity(group_id, requester_user_id, publicise) - - return 200, {} - - -class PublicisedGroupsForUserServlet(RestServlet): - """Get the list of groups a user is advertising""" - - PATTERNS = client_patterns("/publicised_groups/(?P<user_id>[^/]*)$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.store = hs.get_datastores().main - self.groups_handler = hs.get_groups_local_handler() - - async def on_GET( - self, request: SynapseRequest, user_id: str - ) -> Tuple[int, JsonDict]: - await self.auth.get_user_by_req(request, allow_guest=True) - - result = await self.groups_handler.get_publicised_groups_for_user(user_id) - - return 200, result - - -class PublicisedGroupsForUsersServlet(RestServlet): - """Get the list of groups a user is advertising""" - - PATTERNS = client_patterns("/publicised_groups$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.store = hs.get_datastores().main - self.groups_handler = hs.get_groups_local_handler() - - async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - await self.auth.get_user_by_req(request, allow_guest=True) - - content = parse_json_object_from_request(request) - user_ids = content["user_ids"] - - result = await self.groups_handler.bulk_get_publicised_groups(user_ids) - - return 200, result - - -class GroupsForUserServlet(RestServlet): - """Get all groups the logged in user is joined to""" - - PATTERNS = client_patterns("/joined_groups$") - - def __init__(self, hs: "HomeServer"): - super().__init__() - self.auth = hs.get_auth() - self.clock = hs.get_clock() - self.groups_handler = hs.get_groups_local_handler() - - async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - requester_user_id = requester.user.to_string() - - result = await self.groups_handler.get_joined_groups(requester_user_id) - - return 200, result - - -def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: - GroupServlet(hs).register(http_server) - GroupSummaryServlet(hs).register(http_server) - GroupInvitedUsersServlet(hs).register(http_server) - GroupUsersServlet(hs).register(http_server) - GroupRoomServlet(hs).register(http_server) - GroupSettingJoinPolicyServlet(hs).register(http_server) - GroupCreateServlet(hs).register(http_server) - GroupAdminRoomsServlet(hs).register(http_server) - GroupAdminRoomsConfigServlet(hs).register(http_server) - GroupAdminUsersInviteServlet(hs).register(http_server) - GroupAdminUsersKickServlet(hs).register(http_server) - GroupSelfLeaveServlet(hs).register(http_server) - GroupSelfJoinServlet(hs).register(http_server) - GroupSelfAcceptInviteServlet(hs).register(http_server) - GroupsForUserServlet(hs).register(http_server) - GroupCategoryServlet(hs).register(http_server) - GroupCategoriesServlet(hs).register(http_server) - GroupSummaryRoomsCatServlet(hs).register(http_server) - GroupRoleServlet(hs).register(http_server) - GroupRolesServlet(hs).register(http_server) - GroupSelfUpdatePublicityServlet(hs).register(http_server) - GroupSummaryUsersRoleServlet(hs).register(http_server) - PublicisedGroupsForUserServlet(hs).register(http_server) - PublicisedGroupsForUsersServlet(hs).register(http_server) diff --git a/synapse/rest/client/mutual_rooms.py b/synapse/rest/client/mutual_rooms.py index 27bfaf0b29..38ef4e459f 100644 --- a/synapse/rest/client/mutual_rooms.py +++ b/synapse/rest/client/mutual_rooms.py @@ -42,21 +42,10 @@ class UserMutualRoomsServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastores().main - self.user_directory_search_enabled = ( - hs.config.userdirectory.user_directory_search_enabled - ) async def on_GET( self, request: SynapseRequest, user_id: str ) -> Tuple[int, JsonDict]: - - if not self.user_directory_search_enabled: - raise SynapseError( - code=400, - msg="User directory searching is disabled. Cannot determine shared rooms.", - errcode=Codes.UNKNOWN, - ) - UserID.from_string(user_id) requester = await self.auth.get_user_by_req(request) @@ -67,8 +56,8 @@ class UserMutualRoomsServlet(RestServlet): errcode=Codes.FORBIDDEN, ) - rooms = await self.store.get_mutual_rooms_for_users( - requester.user.to_string(), user_id + rooms = await self.store.get_mutual_rooms_between_users( + frozenset((requester.user.to_string(), user_id)) ) return 200, {"joined": list(rooms)} diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 5a2361a2e6..7a5ce8ad0e 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -1193,12 +1193,7 @@ class TimestampLookupRestServlet(RestServlet): class RoomHierarchyRestServlet(RestServlet): - PATTERNS = ( - re.compile( - "^/_matrix/client/(v1|unstable/org.matrix.msc2946)" - "/rooms/(?P<room_id>[^/]*)/hierarchy$" - ), - ) + PATTERNS = (re.compile("^/_matrix/client/v1/rooms/(?P<room_id>[^/]*)/hierarchy$"),) def __init__(self, hs: "HomeServer"): super().__init__() diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index e8772f86e7..8bbf35148d 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -16,7 +16,7 @@ import logging from collections import defaultdict from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union -from synapse.api.constants import Membership, PresenceState +from synapse.api.constants import EduTypes, Membership, PresenceState from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.filtering import FilterCollection from synapse.api.presence import UserPresenceState @@ -298,14 +298,6 @@ class SyncRestServlet(RestServlet): if archived: response["rooms"][Membership.LEAVE] = archived - if sync_result.groups is not None: - if sync_result.groups.join: - response["groups"][Membership.JOIN] = sync_result.groups.join - if sync_result.groups.invite: - response["groups"][Membership.INVITE] = sync_result.groups.invite - if sync_result.groups.leave: - response["groups"][Membership.LEAVE] = sync_result.groups.leave - return response @staticmethod @@ -313,7 +305,7 @@ class SyncRestServlet(RestServlet): return { "events": [ { - "type": "m.presence", + "type": EduTypes.PRESENCE, "sender": event.user_id, "content": format_user_presence_state( event, time_now, include_user_id=False diff --git a/synapse/rest/media/v1/preview_html.py b/synapse/rest/media/v1/preview_html.py index 0358c68a64..13ec7ab533 100644 --- a/synapse/rest/media/v1/preview_html.py +++ b/synapse/rest/media/v1/preview_html.py @@ -289,7 +289,7 @@ def parse_html_description(tree: "etree.Element") -> Optional[str]: def _iterate_over_text( - tree: "etree.Element", *tags_to_ignore: Iterable[Union[str, "etree.Comment"]] + tree: "etree.Element", *tags_to_ignore: Union[str, "etree.Comment"] ) -> Generator[str, None, None]: """Iterate over the tree returning text nodes in a depth first fashion, skipping text nodes inside certain tags. diff --git a/synapse/server.py b/synapse/server.py index ee60cce8eb..a66ec228db 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -21,17 +21,7 @@ import abc import functools import logging -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - List, - Optional, - TypeVar, - Union, - cast, -) +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast from twisted.internet.interfaces import IOpenSSLContextFactory from twisted.internet.tcp import Port @@ -60,8 +50,6 @@ from synapse.federation.federation_server import ( from synapse.federation.send_queue import FederationRemoteSendQueue from synapse.federation.sender import AbstractFederationSender, FederationSender from synapse.federation.transport.client import TransportLayerClient -from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer -from synapse.groups.groups_server import GroupsServerHandler, GroupsServerWorkerHandler from synapse.handlers.account import AccountHandler from synapse.handlers.account_data import AccountDataHandler from synapse.handlers.account_validity import AccountValidityHandler @@ -79,7 +67,6 @@ from synapse.handlers.event_auth import EventAuthHandler from synapse.handlers.events import EventHandler, EventStreamHandler from synapse.handlers.federation import FederationHandler from synapse.handlers.federation_event import FederationEventHandler -from synapse.handlers.groups_local import GroupsLocalHandler, GroupsLocalWorkerHandler from synapse.handlers.identity import IdentityHandler from synapse.handlers.initial_sync import InitialSyncHandler from synapse.handlers.message import EventCreationHandler, MessageHandler @@ -136,7 +123,8 @@ from synapse.server_notices.worker_server_notices_sender import ( WorkerServerNoticesSender, ) from synapse.state import StateHandler, StateResolutionHandler -from synapse.storage import Databases, Storage +from synapse.storage import Databases +from synapse.storage.controllers import StorageControllers from synapse.streams.events import EventSources from synapse.types import DomainSpecificString, ISynapseReactor from synapse.util import Clock @@ -652,30 +640,6 @@ class HomeServer(metaclass=abc.ABCMeta): return UserDirectoryHandler(self) @cache_in_self - def get_groups_local_handler( - self, - ) -> Union[GroupsLocalWorkerHandler, GroupsLocalHandler]: - if self.config.worker.worker_app: - return GroupsLocalWorkerHandler(self) - else: - return GroupsLocalHandler(self) - - @cache_in_self - def get_groups_server_handler(self): - if self.config.worker.worker_app: - return GroupsServerWorkerHandler(self) - else: - return GroupsServerHandler(self) - - @cache_in_self - def get_groups_attestation_signing(self) -> GroupAttestationSigning: - return GroupAttestationSigning(self) - - @cache_in_self - def get_groups_attestation_renewer(self) -> GroupAttestionRenewer: - return GroupAttestionRenewer(self) - - @cache_in_self def get_stats_handler(self) -> StatsHandler: return StatsHandler(self) @@ -766,8 +730,8 @@ class HomeServer(metaclass=abc.ABCMeta): return PasswordPolicyHandler(self) @cache_in_self - def get_storage(self) -> Storage: - return Storage(self, self.get_datastores()) + def get_storage_controllers(self) -> StorageControllers: + return StorageControllers(self, self.get_datastores()) @cache_in_self def get_replication_streamer(self) -> ReplicationStreamer: diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 4b4ed42cff..bf09f5128a 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -127,10 +127,10 @@ class StateHandler: def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.store = hs.get_datastores().main - self.state_store = hs.get_storage().state + self._state_storage_controller = hs.get_storage_controllers().state self.hs = hs self._state_resolution_handler = hs.get_state_resolution_handler() - self._storage = hs.get_storage() + self._storage_controllers = hs.get_storage_controllers() @overload async def get_current_state( @@ -261,7 +261,7 @@ class StateHandler: async def compute_event_context( self, event: EventBase, - old_state: Optional[Iterable[EventBase]] = None, + state_ids_before_event: Optional[StateMap[str]] = None, partial_state: bool = False, ) -> EventContext: """Build an EventContext structure for a non-outlier event. @@ -273,12 +273,12 @@ class StateHandler: Args: event: - old_state: The state at the event if it can't be - calculated from existing events. This is normally only specified - when receiving an event from federation where we don't have the - prev events for, e.g. when backfilling. - partial_state: True if `old_state` is partial and omits non-critical - membership events + state_ids_before_event: The event ids of the state before the event if + it can't be calculated from existing events. This is normally + only specified when receiving an event from federation where we + don't have the prev events, e.g. when backfilling. + partial_state: True if `state_ids_before_event` is partial and omits + non-critical membership events Returns: The event context. """ @@ -286,13 +286,11 @@ class StateHandler: assert not event.internal_metadata.is_outlier() # - # first of all, figure out the state before the event + # first of all, figure out the state before the event, unless we + # already have it. # - if old_state: + if state_ids_before_event: # if we're given the state before the event, then we use that - state_ids_before_event: StateMap[str] = { - (s.type, s.state_key): s.event_id for s in old_state - } state_group_before_event = None state_group_before_event_prev_group = None deltas_to_state_group_before_event = None @@ -339,12 +337,14 @@ class StateHandler: # if not state_group_before_event: - state_group_before_event = await self.state_store.store_state_group( - event.event_id, - event.room_id, - prev_group=state_group_before_event_prev_group, - delta_ids=deltas_to_state_group_before_event, - current_state_ids=state_ids_before_event, + state_group_before_event = ( + await self._state_storage_controller.store_state_group( + event.event_id, + event.room_id, + prev_group=state_group_before_event_prev_group, + delta_ids=deltas_to_state_group_before_event, + current_state_ids=state_ids_before_event, + ) ) # Assign the new state group to the cached state entry. @@ -361,7 +361,7 @@ class StateHandler: if not event.is_state(): return EventContext.with_state( - storage=self._storage, + storage=self._storage_controllers, state_group_before_event=state_group_before_event, state_group=state_group_before_event, state_delta_due_to_event={}, @@ -384,16 +384,18 @@ class StateHandler: state_ids_after_event[key] = event.event_id delta_ids = {key: event.event_id} - state_group_after_event = await self.state_store.store_state_group( - event.event_id, - event.room_id, - prev_group=state_group_before_event, - delta_ids=delta_ids, - current_state_ids=state_ids_after_event, + state_group_after_event = ( + await self._state_storage_controller.store_state_group( + event.event_id, + event.room_id, + prev_group=state_group_before_event, + delta_ids=delta_ids, + current_state_ids=state_ids_after_event, + ) ) return EventContext.with_state( - storage=self._storage, + storage=self._storage_controllers, state_group=state_group_after_event, state_group_before_event=state_group_before_event, state_delta_due_to_event=delta_ids, @@ -418,7 +420,9 @@ class StateHandler: """ logger.debug("resolve_state_groups event_ids %s", event_ids) - state_groups = await self.state_store.get_state_group_for_events(event_ids) + state_groups = await self._state_storage_controller.get_state_group_for_events( + event_ids + ) state_group_ids = state_groups.values() @@ -426,8 +430,13 @@ class StateHandler: state_group_ids_set = set(state_group_ids) if len(state_group_ids_set) == 1: (state_group_id,) = state_group_ids_set - state = await self.state_store.get_state_for_groups(state_group_ids_set) - prev_group, delta_ids = await self.state_store.get_state_group_delta( + state = await self._state_storage_controller.get_state_for_groups( + state_group_ids_set + ) + ( + prev_group, + delta_ids, + ) = await self._state_storage_controller.get_state_group_delta( state_group_id ) return _StateCacheEntry( @@ -441,7 +450,7 @@ class StateHandler: room_version = await self.store.get_room_version_id(room_id) - state_to_resolve = await self.state_store.get_state_for_groups( + state_to_resolve = await self._state_storage_controller.get_state_for_groups( state_group_ids_set ) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 105e4e1fec..bac21ecf9c 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -18,41 +18,20 @@ The storage layer is split up into multiple parts to allow Synapse to run against different configurations of databases (e.g. single or multiple databases). The `DatabasePool` class represents connections to a single physical database. The `databases` are classes that talk directly to a `DatabasePool` -instance and have associated schemas, background updates, etc. On top of those -there are classes that provide high level interfaces that combine calls to -multiple `databases`. +instance and have associated schemas, background updates, etc. + +On top of the databases are the StorageControllers, located in the +`synapse.storage.controllers` module. These classes provide high level +interfaces that combine calls to multiple `databases`. They are bundled into the +`StorageControllers` singleton for ease of use, and exposed via +`HomeServer.get_storage_controllers()`. There are also schemas that get applied to every database, regardless of the data stores associated with them (e.g. the schema version tables), which are stored in `synapse.storage.schema`. """ -from typing import TYPE_CHECKING from synapse.storage.databases import Databases from synapse.storage.databases.main import DataStore -from synapse.storage.persist_events import EventsPersistenceStorage -from synapse.storage.purge_events import PurgeEventsStorage -from synapse.storage.state import StateGroupStorage - -if TYPE_CHECKING: - from synapse.server import HomeServer - __all__ = ["Databases", "DataStore"] - - -class Storage: - """The high level interfaces for talking to various storage layers.""" - - def __init__(self, hs: "HomeServer", stores: Databases): - # We include the main data store here mainly so that we don't have to - # rewrite all the existing code to split it into high vs low level - # interfaces. - self.main = stores.main - - self.purge_events = PurgeEventsStorage(hs, stores) - self.state = StateGroupStorage(hs, stores) - - self.persistence = None - if stores.persist_events: - self.persistence = EventsPersistenceStorage(hs, stores) diff --git a/synapse/storage/controllers/__init__.py b/synapse/storage/controllers/__init__.py new file mode 100644 index 0000000000..992261d07b --- /dev/null +++ b/synapse/storage/controllers/__init__.py @@ -0,0 +1,46 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from synapse.storage.controllers.persist_events import ( + EventsPersistenceStorageController, +) +from synapse.storage.controllers.purge_events import PurgeEventsStorageController +from synapse.storage.controllers.state import StateGroupStorageController +from synapse.storage.databases import Databases +from synapse.storage.databases.main import DataStore + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +__all__ = ["Databases", "DataStore"] + + +class StorageControllers: + """The high level interfaces for talking to various storage controller layers.""" + + def __init__(self, hs: "HomeServer", stores: Databases): + # We include the main data store here mainly so that we don't have to + # rewrite all the existing code to split it into high vs low level + # interfaces. + self.main = stores.main + + self.purge_events = PurgeEventsStorageController(hs, stores) + self.state = StateGroupStorageController(hs, stores) + + self.persistence = None + if stores.persist_events: + self.persistence = EventsPersistenceStorageController(hs, stores) diff --git a/synapse/storage/persist_events.py b/synapse/storage/controllers/persist_events.py index 0fc282866b..ef8c135b12 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -272,7 +272,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]): pass -class EventsPersistenceStorage: +class EventsPersistenceStorageController: """High level interface for handling persisting newly received events. Takes care of batching up events by room, and calculating the necessary @@ -313,7 +313,7 @@ class EventsPersistenceStorage: List of events persisted, the current position room stream position. The list of events persisted may not be the same as those passed in if they were deduplicated due to an event already existing that - matched the transcation ID; the existing event is returned in such + matched the transaction ID; the existing event is returned in such a case. """ partitioned: Dict[str, List[Tuple[EventBase, EventContext]]] = {} diff --git a/synapse/storage/purge_events.py b/synapse/storage/controllers/purge_events.py index 30669beb7c..9ca50d6a09 100644 --- a/synapse/storage/purge_events.py +++ b/synapse/storage/controllers/purge_events.py @@ -24,7 +24,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class PurgeEventsStorage: +class PurgeEventsStorageController: """High level interface for purging rooms and event history.""" def __init__(self, hs: "HomeServer", stores: Databases): diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py new file mode 100644 index 0000000000..0f09953086 --- /dev/null +++ b/synapse/storage/controllers/state.py @@ -0,0 +1,351 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import ( + TYPE_CHECKING, + Awaitable, + Collection, + Dict, + Iterable, + List, + Mapping, + Optional, + Tuple, +) + +from synapse.events import EventBase +from synapse.storage.state import StateFilter +from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker +from synapse.types import MutableStateMap, StateMap + +if TYPE_CHECKING: + from synapse.server import HomeServer + from synapse.storage.databases import Databases + +logger = logging.getLogger(__name__) + + +class StateGroupStorageController: + """High level interface to fetching state for event.""" + + def __init__(self, hs: "HomeServer", stores: "Databases"): + self._is_mine_id = hs.is_mine_id + self.stores = stores + self._partial_state_events_tracker = PartialStateEventsTracker(stores.main) + + def notify_event_un_partial_stated(self, event_id: str) -> None: + self._partial_state_events_tracker.notify_un_partial_stated(event_id) + + async def get_state_group_delta( + self, state_group: int + ) -> Tuple[Optional[int], Optional[StateMap[str]]]: + """Given a state group try to return a previous group and a delta between + the old and the new. + + Args: + state_group: The state group used to retrieve state deltas. + + Returns: + A tuple of the previous group and a state map of the event IDs which + make up the delta between the old and new state groups. + """ + + state_group_delta = await self.stores.state.get_state_group_delta(state_group) + return state_group_delta.prev_group, state_group_delta.delta_ids + + async def get_state_groups_ids( + self, _room_id: str, event_ids: Collection[str] + ) -> Dict[int, MutableStateMap[str]]: + """Get the event IDs of all the state for the state groups for the given events + + Args: + _room_id: id of the room for these events + event_ids: ids of the events + + Returns: + dict of state_group_id -> (dict of (type, state_key) -> event id) + + Raises: + RuntimeError if we don't have a state group for one or more of the events + (ie they are outliers or unknown) + """ + if not event_ids: + return {} + + event_to_groups = await self.get_state_group_for_events(event_ids) + + groups = set(event_to_groups.values()) + group_to_state = await self.stores.state._get_state_for_groups(groups) + + return group_to_state + + async def get_state_ids_for_group( + self, state_group: int, state_filter: Optional[StateFilter] = None + ) -> StateMap[str]: + """Get the event IDs of all the state in the given state group + + Args: + state_group: A state group for which we want to get the state IDs. + state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules + + Returns: + Resolves to a map of (type, state_key) -> event_id + """ + group_to_state = await self.get_state_for_groups((state_group,), state_filter) + + return group_to_state[state_group] + + async def get_state_groups( + self, room_id: str, event_ids: Collection[str] + ) -> Dict[int, List[EventBase]]: + """Get the state groups for the given list of event_ids + + Args: + room_id: ID of the room for these events. + event_ids: The event IDs to retrieve state for. + + Returns: + dict of state_group_id -> list of state events. + """ + if not event_ids: + return {} + + group_to_ids = await self.get_state_groups_ids(room_id, event_ids) + + state_event_map = await self.stores.main.get_events( + [ + ev_id + for group_ids in group_to_ids.values() + for ev_id in group_ids.values() + ], + get_prev_content=False, + ) + + return { + group: [ + state_event_map[v] + for v in event_id_map.values() + if v in state_event_map + ] + for group, event_id_map in group_to_ids.items() + } + + def _get_state_groups_from_groups( + self, groups: List[int], state_filter: StateFilter + ) -> Awaitable[Dict[int, StateMap[str]]]: + """Returns the state groups for a given set of groups, filtering on + types of state events. + + Args: + groups: list of state group IDs to query + state_filter: The state filter used to fetch state + from the database. + + Returns: + Dict of state group to state map. + """ + + return self.stores.state._get_state_groups_from_groups(groups, state_filter) + + async def get_state_for_events( + self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None + ) -> Dict[str, StateMap[EventBase]]: + """Given a list of event_ids and type tuples, return a list of state + dicts for each event. + + Args: + event_ids: The events to fetch the state of. + state_filter: The state filter used to fetch state. + + Returns: + A dict of (event_id) -> (type, state_key) -> [state_events] + + Raises: + RuntimeError if we don't have a state group for one or more of the events + (ie they are outliers or unknown) + """ + await_full_state = True + if state_filter and not state_filter.must_await_full_state(self._is_mine_id): + await_full_state = False + + event_to_groups = await self.get_state_group_for_events( + event_ids, await_full_state=await_full_state + ) + + groups = set(event_to_groups.values()) + group_to_state = await self.stores.state._get_state_for_groups( + groups, state_filter or StateFilter.all() + ) + + state_event_map = await self.stores.main.get_events( + [ev_id for sd in group_to_state.values() for ev_id in sd.values()], + get_prev_content=False, + ) + + event_to_state = { + event_id: { + k: state_event_map[v] + for k, v in group_to_state[group].items() + if v in state_event_map + } + for event_id, group in event_to_groups.items() + } + + return {event: event_to_state[event] for event in event_ids} + + async def get_state_ids_for_events( + self, + event_ids: Collection[str], + state_filter: Optional[StateFilter] = None, + ) -> Dict[str, StateMap[str]]: + """ + Get the state dicts corresponding to a list of events, containing the event_ids + of the state events (as opposed to the events themselves) + + Args: + event_ids: events whose state should be returned + state_filter: The state filter used to fetch state from the database. + + Returns: + A dict from event_id -> (type, state_key) -> event_id + + Raises: + RuntimeError if we don't have a state group for one or more of the events + (ie they are outliers or unknown) + """ + await_full_state = True + if state_filter and not state_filter.must_await_full_state(self._is_mine_id): + await_full_state = False + + event_to_groups = await self.get_state_group_for_events( + event_ids, await_full_state=await_full_state + ) + + groups = set(event_to_groups.values()) + group_to_state = await self.stores.state._get_state_for_groups( + groups, state_filter or StateFilter.all() + ) + + event_to_state = { + event_id: group_to_state[group] + for event_id, group in event_to_groups.items() + } + + return {event: event_to_state[event] for event in event_ids} + + async def get_state_for_event( + self, event_id: str, state_filter: Optional[StateFilter] = None + ) -> StateMap[EventBase]: + """ + Get the state dict corresponding to a particular event + + Args: + event_id: event whose state should be returned + state_filter: The state filter used to fetch state from the database. + + Returns: + A dict from (type, state_key) -> state_event + + Raises: + RuntimeError if we don't have a state group for the event (ie it is an + outlier or is unknown) + """ + state_map = await self.get_state_for_events( + [event_id], state_filter or StateFilter.all() + ) + return state_map[event_id] + + async def get_state_ids_for_event( + self, event_id: str, state_filter: Optional[StateFilter] = None + ) -> StateMap[str]: + """ + Get the state dict corresponding to a particular event + + Args: + event_id: event whose state should be returned + state_filter: The state filter used to fetch state from the database. + + Returns: + A dict from (type, state_key) -> state_event_id + + Raises: + RuntimeError if we don't have a state group for the event (ie it is an + outlier or is unknown) + """ + state_map = await self.get_state_ids_for_events( + [event_id], state_filter or StateFilter.all() + ) + return state_map[event_id] + + def get_state_for_groups( + self, groups: Iterable[int], state_filter: Optional[StateFilter] = None + ) -> Awaitable[Dict[int, MutableStateMap[str]]]: + """Gets the state at each of a list of state groups, optionally + filtering by type/state_key + + Args: + groups: list of state groups for which we want to get the state. + state_filter: The state filter used to fetch state. + from the database. + + Returns: + Dict of state group to state map. + """ + return self.stores.state._get_state_for_groups( + groups, state_filter or StateFilter.all() + ) + + async def get_state_group_for_events( + self, + event_ids: Collection[str], + await_full_state: bool = True, + ) -> Mapping[str, int]: + """Returns mapping event_id -> state_group + + Args: + event_ids: events to get state groups for + await_full_state: if true, will block if we do not yet have complete + state at these events. + """ + if await_full_state: + await self._partial_state_events_tracker.await_full_state(event_ids) + + return await self.stores.main._get_state_group_for_events(event_ids) + + async def store_state_group( + self, + event_id: str, + room_id: str, + prev_group: Optional[int], + delta_ids: Optional[StateMap[str]], + current_state_ids: StateMap[str], + ) -> int: + """Store a new set of state, returning a newly assigned state group. + + Args: + event_id: The event ID for which the state was calculated. + room_id: ID of the room for which the state was calculated. + prev_group: A previous state group for the room, optional. + delta_ids: The delta between state at `prev_group` and + `current_state_ids`, if `prev_group` was given. Same format as + `current_state_ids`. + current_state_ids: The state to store. Map of (type, state_key) + to event_id. + + Returns: + The state group ID + """ + return await self.stores.state.store_state_group( + event_id, room_id, prev_group, delta_ids, current_state_ids + ) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 2df4dd4ed4..d900064c07 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -28,6 +28,7 @@ from typing import ( cast, ) +from synapse.api.constants import EduTypes from synapse.api.errors import Codes, StoreError from synapse.logging.opentracing import ( get_active_span_text_map, @@ -419,7 +420,7 @@ class DeviceWorkerStore(SQLBaseStore): # Add the updated cross-signing keys to the results list for user_id, result in cross_signing_keys_by_user.items(): result["user_id"] = user_id - results.append(("m.signing_key_update", result)) + results.append((EduTypes.SIGNING_KEY_UPDATE, result)) # also send the unstable version # FIXME: remove this when enough servers have upgraded # and remove the length budgeting above. @@ -545,7 +546,7 @@ class DeviceWorkerStore(SQLBaseStore): else: result["deleted"] = True - results.append(("m.device_list_update", result)) + results.append((EduTypes.DEVICE_LIST_UPDATE, result)) return results @@ -1153,6 +1154,45 @@ class DeviceWorkerStore(SQLBaseStore): _prune_txn, ) + async def get_local_devices_not_accessed_since( + self, since_ms: int + ) -> Dict[str, List[str]]: + """Retrieves local devices that haven't been accessed since a given date. + + Args: + since_ms: the timestamp to select on, every device with a last access date + from before that time is returned. + + Returns: + A dictionary with an entry for each user with at least one device matching + the request, which value is a list of the device ID(s) for the corresponding + device(s). + """ + + def get_devices_not_accessed_since_txn( + txn: LoggingTransaction, + ) -> List[Dict[str, str]]: + sql = """ + SELECT user_id, device_id + FROM devices WHERE last_seen < ? AND hidden = FALSE + """ + txn.execute(sql, (since_ms,)) + return self.db_pool.cursor_to_dict(txn) + + rows = await self.db_pool.runInteraction( + "get_devices_not_accessed_since", + get_devices_not_accessed_since_txn, + ) + + devices: Dict[str, List[str]] = {} + for row in rows: + # Remote devices are never stale from our point of view. + if self.hs.is_mine_id(row["user_id"]): + user_devices = devices.setdefault(row["user_id"], []) + user_devices.append(row["device_id"]) + + return devices + class DeviceBackgroundUpdateStore(SQLBaseStore): def __init__( diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index dcfe8caf47..eec55b6478 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -1057,7 +1057,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas INNER JOIN batch_events AS c ON i.next_batch_id = c.batch_id /* Get the depth of the batch start event from the events table */ - INNER JOIN events AS e USING (event_id) + INNER JOIN events AS e ON c.event_id = e.event_id /* Find an insertion event which matches the given event_id */ WHERE i.event_id = ? LIMIT ? @@ -1318,17 +1318,14 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas query = ( "SELECT prev_event_id FROM event_edges " - "WHERE room_id = ? AND event_id = ? AND is_state = ? " + "WHERE event_id = ? AND NOT is_state " "LIMIT ?" ) while front and len(event_results) < limit: new_front = set() for event_id in front: - txn.execute( - query, (room_id, event_id, False, limit - len(event_results)) - ) - + txn.execute(query, (event_id, limit - len(event_results))) new_results = {t[0] for t in txn} - seen_events new_front |= new_results diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index b7c4c62222..b019979350 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -938,7 +938,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): users can still get a list of recent highlights. Args: - txn: The transcation + txn: The transaction room_id: Room ID to delete from user_id: user ID to delete for stream_ordering: The lowest stream ordering which will diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 0df8ff5395..17e35cf63e 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1828,6 +1828,10 @@ class PersistEventsStore: self.store.get_aggregation_groups_for_event.invalidate, (relation.parent_id,), ) + txn.call_after( + self.store.get_mutual_event_relations_for_rel_type.invalidate, + (relation.parent_id,), + ) if relation.rel_type == RelationTypes.REPLACE: txn.call_after( @@ -2004,6 +2008,11 @@ class PersistEventsStore: self.store._invalidate_cache_and_stream( txn, self.store.get_thread_participated, (redacted_relates_to,) ) + self.store._invalidate_cache_and_stream( + txn, + self.store.get_mutual_event_relations_for_rel_type, + (redacted_relates_to,), + ) self.db_pool.simple_delete_txn( txn, table="event_relations", keyvalues={"event_id": redacted_event_id} diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 5b22d6b452..b99b107784 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1356,14 +1356,23 @@ class EventsWorkerStore(SQLBaseStore): Returns: The set of events we have already seen. """ - res = await self._have_seen_events_dict( - (room_id, event_id) for event_id in event_ids - ) - return {eid for ((_rid, eid), have_event) in res.items() if have_event} + + # @cachedList chomps lots of memory if you call it with a big list, so + # we break it down. However, each batch requires its own index scan, so we make + # the batches as big as possible. + + results: Set[str] = set() + for chunk in batch_iter(event_ids, 500): + r = await self._have_seen_events_dict( + [(room_id, event_id) for event_id in chunk] + ) + results.update(eid for ((_rid, eid), have_event) in r.items() if have_event) + + return results @cachedList(cached_method_name="have_seen_event", list_name="keys") async def _have_seen_events_dict( - self, keys: Iterable[Tuple[str, str]] + self, keys: Collection[Tuple[str, str]] ) -> Dict[Tuple[str, str], bool]: """Helper for have_seen_events @@ -1375,11 +1384,12 @@ class EventsWorkerStore(SQLBaseStore): cache_results = { (rid, eid) for (rid, eid) in keys if self._get_event_cache.contains((eid,)) } - results = {x: True for x in cache_results} + results = dict.fromkeys(cache_results, True) + remaining = [k for k in keys if k not in cache_results] + if not remaining: + return results - def have_seen_events_txn( - txn: LoggingTransaction, chunk: Tuple[Tuple[str, str], ...] - ) -> None: + def have_seen_events_txn(txn: LoggingTransaction) -> None: # we deliberately do *not* query the database for room_id, to make the # query an index-only lookup on `events_event_id_key`. # @@ -1387,21 +1397,17 @@ class EventsWorkerStore(SQLBaseStore): sql = "SELECT event_id FROM events AS e WHERE " clause, args = make_in_list_sql_clause( - txn.database_engine, "e.event_id", [eid for (_rid, eid) in chunk] + txn.database_engine, "e.event_id", [eid for (_rid, eid) in remaining] ) txn.execute(sql + clause, args) found_events = {eid for eid, in txn} - # ... and then we can update the results for each row in the batch - results.update({(rid, eid): (eid in found_events) for (rid, eid) in chunk}) - - # each batch requires its own index scan, so we make the batches as big as - # possible. - for chunk in batch_iter((k for k in keys if k not in cache_results), 500): - await self.db_pool.runInteraction( - "have_seen_events", have_seen_events_txn, chunk + # ... and then we can update the results for each key + results.update( + {(rid, eid): (eid in found_events) for (rid, eid) in remaining} ) + await self.db_pool.runInteraction("have_seen_events", have_seen_events_txn) return results @cached(max_entries=100000, tree=True) @@ -1922,6 +1928,18 @@ class EventsWorkerStore(SQLBaseStore): LIMIT 1 """ + # We consider any forward extremity as the latest in the room and + # not a forward gap. + # + # To expand, even though there is technically a gap at the front of + # the room where the forward extremities are, we consider those the + # latest messages in the room so asking other homeservers for more + # is useless. The new latest messages will just be federated as + # usual. + txn.execute(forward_extremity_query, (event.room_id, event.event_id)) + if txn.fetchone(): + return False + # Check to see whether the event in question is already referenced # by another event. If we don't see any edges, we're next to a # forward gap. @@ -1930,8 +1948,7 @@ class EventsWorkerStore(SQLBaseStore): /* Check to make sure the event referencing our event in question is not rejected */ LEFT JOIN rejections ON event_edges.event_id = rejections.event_id WHERE - event_edges.room_id = ? - AND event_edges.prev_event_id = ? + event_edges.prev_event_id = ? /* It's not a valid edge if the event referencing our event in * question is rejected. */ @@ -1939,25 +1956,11 @@ class EventsWorkerStore(SQLBaseStore): LIMIT 1 """ - # We consider any forward extremity as the latest in the room and - # not a forward gap. - # - # To expand, even though there is technically a gap at the front of - # the room where the forward extremities are, we consider those the - # latest messages in the room so asking other homeservers for more - # is useless. The new latest messages will just be federated as - # usual. - txn.execute(forward_extremity_query, (event.room_id, event.event_id)) - forward_extremities = txn.fetchall() - if len(forward_extremities): - return False - # If there are no forward edges to the event in question (another # event hasn't referenced this event in their prev_events), then we # assume there is a forward gap in the history. - txn.execute(forward_edge_query, (event.room_id, event.event_id)) - forward_edges = txn.fetchall() - if not len(forward_edges): + txn.execute(forward_edge_query, (event.event_id,)) + if not txn.fetchone(): return True return False diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py index bedacaf0d7..2d7633fbd5 100644 --- a/synapse/storage/databases/main/lock.py +++ b/synapse/storage/databases/main/lock.py @@ -13,7 +13,7 @@ # limitations under the License. import logging from types import TracebackType -from typing import TYPE_CHECKING, Optional, Tuple, Type +from typing import TYPE_CHECKING, Optional, Set, Tuple, Type from weakref import WeakValueDictionary from twisted.internet.interfaces import IReactorCore @@ -84,6 +84,8 @@ class LockStore(SQLBaseStore): self._on_shutdown, ) + self._acquiring_locks: Set[Tuple[str, str]] = set() + @wrap_as_background_process("LockStore._on_shutdown") async def _on_shutdown(self) -> None: """Called when the server is shutting down""" @@ -103,6 +105,21 @@ class LockStore(SQLBaseStore): context manager if the lock is successfully acquired, which *must* be used (otherwise the lock will leak). """ + if (lock_name, lock_key) in self._acquiring_locks: + return None + try: + self._acquiring_locks.add((lock_name, lock_key)) + return await self._try_acquire_lock(lock_name, lock_key) + finally: + self._acquiring_locks.discard((lock_name, lock_key)) + + async def _try_acquire_lock( + self, lock_name: str, lock_key: str + ) -> Optional["Lock"]: + """Try to acquire a lock for the given name/key. Will return an async + context manager if the lock is successfully acquired, which *must* be + used (otherwise the lock will leak). + """ # Check if this process has taken out a lock and if it's still valid. lock = self._live_tokens.get((lock_name, lock_key)) diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index 5beb8f1d4b..9a63f953fb 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -122,6 +122,51 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore): "count_users_by_service", _count_users_by_service ) + async def get_monthly_active_users_by_service( + self, start_timestamp: Optional[int] = None, end_timestamp: Optional[int] = None + ) -> List[Tuple[str, str]]: + """Generates list of monthly active users and their services. + Please see "get_monthly_active_count_by_service" docstring for more details + about services. + + Arguments: + start_timestamp: If specified, only include users that were first active + at or after this point + end_timestamp: If specified, only include users that were first active + at or before this point + + Returns: + A list of tuples (appservice_id, user_id). "native" is emitted as the + appservice for users that don't come from appservices (i.e. native Matrix + users). + + """ + if start_timestamp is not None and end_timestamp is not None: + where_clause = 'WHERE "timestamp" >= ? and "timestamp" <= ?' + query_params = [start_timestamp, end_timestamp] + elif start_timestamp is not None: + where_clause = 'WHERE "timestamp" >= ?' + query_params = [start_timestamp] + elif end_timestamp is not None: + where_clause = 'WHERE "timestamp" <= ?' + query_params = [end_timestamp] + else: + where_clause = "" + query_params = [] + + def _list_users(txn: LoggingTransaction) -> List[Tuple[str, str]]: + sql = f""" + SELECT COALESCE(appservice_id, 'native'), user_id + FROM monthly_active_users + LEFT JOIN users ON monthly_active_users.user_id=users.name + {where_clause}; + """ + + txn.execute(sql, query_params) + return cast(List[Tuple[str, str]], txn.fetchall()) + + return await self.db_pool.runInteraction("list_users", _list_users) + async def get_registered_reserved_users(self) -> List[str]: """Of the reserved threepids defined in config, retrieve those that are associated with registered users diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index b47c511450..9769a18a9d 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Tuple, cast +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, cast from synapse.api.presence import PresenceState, UserPresenceState from synapse.replication.tcp.streams import PresenceStream @@ -22,6 +22,7 @@ from synapse.storage.database import ( LoggingDatabaseConnection, LoggingTransaction, ) +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.engines import PostgresEngine from synapse.storage.types import Connection from synapse.storage.util.id_generators import ( @@ -56,7 +57,7 @@ class PresenceBackgroundUpdateStore(SQLBaseStore): ) -class PresenceStore(PresenceBackgroundUpdateStore): +class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore): def __init__( self, database: DatabasePool, @@ -281,20 +282,30 @@ class PresenceStore(PresenceBackgroundUpdateStore): True if the user should have full presence sent to them, False otherwise. """ - def _should_user_receive_full_presence_with_token_txn( - txn: LoggingTransaction, - ) -> bool: - sql = """ - SELECT 1 FROM users_to_send_full_presence_to - WHERE user_id = ? - AND presence_stream_id >= ? - """ - txn.execute(sql, (user_id, from_token)) - return bool(txn.fetchone()) + token = await self._get_full_presence_stream_token_for_user(user_id) + if token is None: + return False - return await self.db_pool.runInteraction( - "should_user_receive_full_presence_with_token", - _should_user_receive_full_presence_with_token_txn, + return from_token <= token + + @cached() + async def _get_full_presence_stream_token_for_user( + self, user_id: str + ) -> Optional[int]: + """Get the presence token corresponding to the last full presence update + for this user. + + If the user presents a sync token with a presence stream token at least + as old as the result, then we need to send them a full presence update. + + If this user has never needed a full presence update, returns `None`. + """ + return await self.db_pool.simple_select_one_onecol( + table="users_to_send_full_presence_to", + keyvalues={"user_id": user_id}, + retcol="presence_stream_id", + allow_none=True, + desc="_get_full_presence_stream_token_for_user", ) async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]) -> None: @@ -307,18 +318,28 @@ class PresenceStore(PresenceBackgroundUpdateStore): # Add user entries to the table, updating the presence_stream_id column if the user already # exists in the table. presence_stream_id = self._presence_id_gen.get_current_token() - await self.db_pool.simple_upsert_many( - table="users_to_send_full_presence_to", - key_names=("user_id",), - key_values=[(user_id,) for user_id in user_ids], - value_names=("presence_stream_id",), - # We save the current presence stream ID token along with the user ID entry so - # that when a user /sync's, even if they syncing multiple times across separate - # devices at different times, each device will receive full presence once - when - # the presence stream ID in their sync token is less than the one in the table - # for their user ID. - value_values=[(presence_stream_id,) for _ in user_ids], - desc="add_users_to_send_full_presence_to", + + def _add_users_to_send_full_presence_to(txn: LoggingTransaction) -> None: + self.db_pool.simple_upsert_many_txn( + txn, + table="users_to_send_full_presence_to", + key_names=("user_id",), + key_values=[(user_id,) for user_id in user_ids], + value_names=("presence_stream_id",), + # We save the current presence stream ID token along with the user ID entry so + # that when a user /sync's, even if they syncing multiple times across separate + # devices at different times, each device will receive full presence once - when + # the presence stream ID in their sync token is less than the one in the table + # for their user ID. + value_values=[(presence_stream_id,) for _ in user_ids], + ) + for user_id in user_ids: + self._invalidate_cache_and_stream( + txn, self._get_full_presence_stream_token_for_user, (user_id,) + ) + + return await self.db_pool.runInteraction( + "add_users_to_send_full_presence_to", _add_users_to_send_full_presence_to ) async def get_presence_for_all_users( diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index c94d5f9f81..2353c120e9 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -322,12 +322,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): ) def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]: - # We *immediately* delete the room from the rooms table. This ensures - # that we don't race when persisting events (as that transaction checks - # that the room exists). - txn.execute("DELETE FROM rooms WHERE room_id = ?", (room_id,)) - - # Next, we fetch all the state groups that should be deleted, before + # First, fetch all the state groups that should be deleted, before # we delete that information. txn.execute( """ @@ -387,7 +382,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): (room_id,), ) - # and finally, the tables with an index on room_id (or no useful index) + # next, the tables with an index on room_id (or no useful index) for table in ( "current_state_events", "destination_rooms", @@ -395,8 +390,13 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): "event_forward_extremities", "event_push_actions", "event_search", + "partial_state_events", "events", + "federation_inbound_events_staging", "group_rooms", + "local_current_membership", + "partial_state_rooms_servers", + "partial_state_rooms", "receipts_graph", "receipts_linearized", "room_aliases", @@ -416,8 +416,9 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): "group_summary_rooms", "room_account_data", "room_tags", - "local_current_membership", - "federation_inbound_events_staging", + # "rooms" happens last, to keep the foreign keys in the other tables + # happy + "rooms", ): logger.info("[purge] removing %s from %s", room_id, table) txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,)) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index ad67901cc1..d5aefe02b6 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -61,6 +61,11 @@ def _is_experimental_rule_enabled( and not experimental_config.msc3786_enabled ): return False + if ( + rule_id == "global/underride/.org.matrix.msc3772.thread_reply" + and not experimental_config.msc3772_enabled + ): + return False return True @@ -169,7 +174,7 @@ class PushRulesWorkerStore( "conditions", "actions", ), - desc="get_push_rules_enabled_for_user", + desc="get_push_rules_for_user", ) rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))) @@ -183,10 +188,10 @@ class PushRulesWorkerStore( results = await self.db_pool.simple_select_list( table="push_rules_enable", keyvalues={"user_name": user_id}, - retcols=("user_name", "rule_id", "enabled"), + retcols=("rule_id", "enabled"), desc="get_push_rules_enabled_for_user", ) - return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results} + return {r["rule_id"]: bool(r["enabled"]) for r in results} async def have_push_rules_changed_for_user( self, user_id: str, last_id: int @@ -208,11 +213,7 @@ class PushRulesWorkerStore( "have_push_rules_changed", have_push_rules_changed_txn ) - @cachedList( - cached_method_name="get_push_rules_for_user", - list_name="user_ids", - num_args=1, - ) + @cachedList(cached_method_name="get_push_rules_for_user", list_name="user_ids") async def bulk_get_push_rules( self, user_ids: Collection[str] ) -> Dict[str, List[JsonDict]]: @@ -244,9 +245,7 @@ class PushRulesWorkerStore( return results @cachedList( - cached_method_name="get_push_rules_enabled_for_user", - list_name="user_ids", - num_args=1, + cached_method_name="get_push_rules_enabled_for_user", list_name="user_ids" ) async def bulk_get_push_rules_enabled( self, user_ids: Collection[str] diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index 91286c9b65..bd0cfa7f32 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -91,12 +91,6 @@ class PusherWorkerStore(SQLBaseStore): yield PusherConfig(**r) - async def user_has_pusher(self, user_id: str) -> bool: - ret = await self.db_pool.simple_select_one_onecol( - "pushers", {"user_name": user_id}, "id", allow_none=True - ) - return ret is not None - async def get_pushers_by_app_id_and_pushkey( self, app_id: str, pushkey: str ) -> Iterator[PusherConfig]: diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index d035969a31..21e954ccc1 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -26,7 +26,7 @@ from typing import ( cast, ) -from synapse.api.constants import ReceiptTypes +from synapse.api.constants import EduTypes, ReceiptTypes from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import ReceiptsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause @@ -363,7 +363,7 @@ class ReceiptsWorkerStore(SQLBaseStore): row["user_id"] ] = db_to_json(row["data"]) - return [{"type": "m.receipt", "room_id": room_id, "content": content}] + return [{"type": EduTypes.RECEIPT, "room_id": room_id, "content": content}] @cachedList( cached_method_name="_get_linearized_receipts_for_room", @@ -411,7 +411,7 @@ class ReceiptsWorkerStore(SQLBaseStore): # receipts by room, event and type. room_event = results.setdefault( row["room_id"], - {"type": "m.receipt", "room_id": row["room_id"], "content": {}}, + {"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}}, ) # The content is of the form: @@ -476,7 +476,7 @@ class ReceiptsWorkerStore(SQLBaseStore): # receipts by room, event and type. room_event = results.setdefault( row["room_id"], - {"type": "m.receipt", "room_id": row["room_id"], "content": {}}, + {"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}}, ) # The content is of the form: @@ -597,7 +597,7 @@ class ReceiptsWorkerStore(SQLBaseStore): return super().process_replication_rows(stream_name, instance_name, token, rows) - def insert_linearized_receipt_txn( + def _insert_linearized_receipt_txn( self, txn: LoggingTransaction, room_id: str, @@ -673,8 +673,11 @@ class ReceiptsWorkerStore(SQLBaseStore): lock=False, ) + # When updating a local users read receipt, remove any push actions + # which resulted from the receipt's event and all earlier events. if ( - receipt_type in (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE) + self.hs.is_mine_id(user_id) + and receipt_type in (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE) and stream_ordering is not None ): self._remove_old_push_actions_before_txn( # type: ignore[attr-defined] @@ -683,6 +686,44 @@ class ReceiptsWorkerStore(SQLBaseStore): return rx_ts + def _graph_to_linear( + self, txn: LoggingTransaction, room_id: str, event_ids: List[str] + ) -> str: + """ + Generate a linearized event from a list of events (i.e. a list of forward + extremities in the room). + + This should allow for calculation of the correct read receipt even if + servers have different event ordering. + + Args: + txn: The transaction + room_id: The room ID the events are in. + event_ids: The list of event IDs to linearize. + + Returns: + The linearized event ID. + """ + # TODO: Make this better. + clause, args = make_in_list_sql_clause( + self.database_engine, "event_id", event_ids + ) + + sql = """ + SELECT event_id WHERE room_id = ? AND stream_ordering IN ( + SELECT max(stream_ordering) WHERE %s + ) + """ % ( + clause, + ) + + txn.execute(sql, [room_id] + list(args)) + rows = txn.fetchall() + if rows: + return rows[0][0] + else: + raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,)) + async def insert_receipt( self, room_id: str, @@ -709,35 +750,14 @@ class ReceiptsWorkerStore(SQLBaseStore): linearized_event_id = event_ids[0] else: # we need to points in graph -> linearized form. - # TODO: Make this better. - def graph_to_linear(txn: LoggingTransaction) -> str: - clause, args = make_in_list_sql_clause( - self.database_engine, "event_id", event_ids - ) - - sql = """ - SELECT event_id WHERE room_id = ? AND stream_ordering IN ( - SELECT max(stream_ordering) WHERE %s - ) - """ % ( - clause, - ) - - txn.execute(sql, [room_id] + list(args)) - rows = txn.fetchall() - if rows: - return rows[0][0] - else: - raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,)) - linearized_event_id = await self.db_pool.runInteraction( - "insert_receipt_conv", graph_to_linear + "insert_receipt_conv", self._graph_to_linear, room_id, event_ids ) async with self._receipts_id_gen.get_next() as stream_id: # type: ignore[attr-defined] event_ts = await self.db_pool.runInteraction( "insert_linearized_receipt", - self.insert_linearized_receipt_txn, + self._insert_linearized_receipt_txn, room_id, receipt_type, user_id, @@ -758,25 +778,9 @@ class ReceiptsWorkerStore(SQLBaseStore): now - event_ts, ) - await self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data) - - max_persisted_id = self._receipts_id_gen.get_current_token() - - return stream_id, max_persisted_id - - async def insert_graph_receipt( - self, - room_id: str, - receipt_type: str, - user_id: str, - event_ids: List[str], - data: JsonDict, - ) -> None: - assert self._can_write_to_receipts - await self.db_pool.runInteraction( "insert_graph_receipt", - self.insert_graph_receipt_txn, + self._insert_graph_receipt_txn, room_id, receipt_type, user_id, @@ -784,7 +788,11 @@ class ReceiptsWorkerStore(SQLBaseStore): data, ) - def insert_graph_receipt_txn( + max_persisted_id = self._receipts_id_gen.get_current_token() + + return stream_id, max_persisted_id + + def _insert_graph_receipt_txn( self, txn: LoggingTransaction, room_id: str, diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index fe8fded88b..b457bc189e 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -767,6 +767,59 @@ class RelationsWorkerStore(SQLBaseStore): "get_if_user_has_annotated_event", _get_if_user_has_annotated_event ) + @cached(iterable=True) + async def get_mutual_event_relations_for_rel_type( + self, event_id: str, relation_type: str + ) -> Set[Tuple[str, str]]: + raise NotImplementedError() + + @cachedList( + cached_method_name="get_mutual_event_relations_for_rel_type", + list_name="relation_types", + ) + async def get_mutual_event_relations( + self, event_id: str, relation_types: Collection[str] + ) -> Dict[str, Set[Tuple[str, str]]]: + """ + Fetch event metadata for events which related to the same event as the given event. + + If the given event has no relation information, returns an empty dictionary. + + Args: + event_id: The event ID which is targeted by relations. + relation_types: The relation types to check for mutual relations. + + Returns: + A dictionary of relation type to: + A set of tuples of: + The sender + The event type + """ + rel_type_sql, rel_type_args = make_in_list_sql_clause( + self.database_engine, "relation_type", relation_types + ) + + sql = f""" + SELECT DISTINCT relation_type, sender, type FROM event_relations + INNER JOIN events USING (event_id) + WHERE relates_to_id = ? AND {rel_type_sql} + """ + + def _get_event_relations( + txn: LoggingTransaction, + ) -> Dict[str, Set[Tuple[str, str]]]: + txn.execute(sql, [event_id] + rel_type_args) + result: Dict[str, Set[Tuple[str, str]]] = { + rel_type: set() for rel_type in relation_types + } + for rel_type, sender, type in txn.fetchall(): + result[rel_type].add((sender, type)) + return result + + return await self.db_pool.runInteraction( + "get_event_relations", _get_event_relations + ) + class RelationsStore(RelationsWorkerStore): pass diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index ded15b92ef..10f2ceb50b 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -233,24 +233,23 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): UNION SELECT room_id from appservice_room_list """ - sql = """ + sql = f""" SELECT COUNT(*) FROM ( - %(published_sql)s + {published_sql} ) published INNER JOIN room_stats_state USING (room_id) INNER JOIN room_stats_current USING (room_id) WHERE ( - join_rules = 'public' OR join_rules = '%(knock_join_rule)s' + join_rules = '{JoinRules.PUBLIC}' + OR join_rules = '{JoinRules.KNOCK}' + OR join_rules = '{JoinRules.KNOCK_RESTRICTED}' OR history_visibility = 'world_readable' ) AND joined_members > 0 - """ % { - "published_sql": published_sql, - "knock_join_rule": JoinRules.KNOCK, - } + """ txn.execute(sql, query_args) return cast(Tuple[int], txn.fetchone())[0] @@ -369,29 +368,29 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): if where_clauses: where_clause = " AND " + " AND ".join(where_clauses) - sql = """ + dir = "DESC" if forwards else "ASC" + sql = f""" SELECT room_id, name, topic, canonical_alias, joined_members, avatar, history_visibility, guest_access, join_rules FROM ( - %(published_sql)s + {published_sql} ) published INNER JOIN room_stats_state USING (room_id) INNER JOIN room_stats_current USING (room_id) WHERE ( - join_rules = 'public' OR join_rules = '%(knock_join_rule)s' + join_rules = '{JoinRules.PUBLIC}' + OR join_rules = '{JoinRules.KNOCK}' + OR join_rules = '{JoinRules.KNOCK_RESTRICTED}' OR history_visibility = 'world_readable' ) AND joined_members > 0 - %(where_clause)s - ORDER BY joined_members %(dir)s, room_id %(dir)s - """ % { - "published_sql": published_sql, - "where_clause": where_clause, - "dir": "DESC" if forwards else "ASC", - "knock_join_rule": JoinRules.KNOCK, - } + {where_clause} + ORDER BY + joined_members {dir}, + room_id {dir} + """ if limit is not None: query_args.append(limit) diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index cc528fcf2d..e222b7bd1f 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -670,6 +670,30 @@ class RoomMemberWorkerStore(EventsWorkerStore): return user_who_share_room + @cached(cache_context=True, iterable=True) + async def get_mutual_rooms_between_users( + self, user_ids: FrozenSet[str], cache_context: _CacheContext + ) -> FrozenSet[str]: + """ + Returns the set of rooms that all users in `user_ids` share. + + Args: + user_ids: A frozen set of all users to investigate and return + overlapping joined rooms for. + cache_context + """ + shared_room_ids: Optional[FrozenSet[str]] = None + for user_id in user_ids: + room_ids = await self.get_rooms_for_user( + user_id, on_invalidate=cache_context.invalidate + ) + if shared_room_ids is not None: + shared_room_ids &= room_ids + else: + shared_room_ids = room_ids + + return shared_room_ids or frozenset() + async def get_joined_users_from_context( self, event: EventBase, context: EventContext ) -> Dict[str, ProfileInfo]: diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 18ae8aee29..a07ad85582 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -16,6 +16,8 @@ import collections.abc import logging from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple +import attr + from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion @@ -26,6 +28,7 @@ from synapse.storage.database import ( DatabasePool, LoggingDatabaseConnection, LoggingTransaction, + make_in_list_sql_clause, ) from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore @@ -33,6 +36,7 @@ from synapse.storage.state import StateFilter from synapse.types import JsonDict, JsonMapping, StateMap from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.iterutils import batch_iter if TYPE_CHECKING: from synapse.server import HomeServer @@ -43,6 +47,15 @@ logger = logging.getLogger(__name__) MAX_STATE_DELTA_HOPS = 100 +@attr.s(slots=True, frozen=True, auto_attribs=True) +class EventMetadata: + """Returned by `get_metadata_for_events`""" + + room_id: str + event_type: str + state_key: Optional[str] + + def _retrieve_and_check_room_version(room_id: str, room_version_id: str) -> RoomVersion: v = KNOWN_ROOM_VERSIONS.get(room_version_id) if not v: @@ -133,6 +146,52 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return room_version + async def get_metadata_for_events( + self, event_ids: Collection[str] + ) -> Dict[str, EventMetadata]: + """Get some metadata (room_id, type, state_key) for the given events. + + This method is a faster alternative than fetching the full events from + the DB, and should be used when the full event is not needed. + + Returns metadata for rejected and redacted events. Events that have not + been persisted are omitted from the returned dict. + """ + + def get_metadata_for_events_txn( + txn: LoggingTransaction, + batch_ids: Collection[str], + ) -> Dict[str, EventMetadata]: + clause, args = make_in_list_sql_clause( + self.database_engine, "e.event_id", batch_ids + ) + + sql = f""" + SELECT e.event_id, e.room_id, e.type, se.state_key FROM events AS e + LEFT JOIN state_events se USING (event_id) + WHERE {clause} + """ + + txn.execute(sql, args) + return { + event_id: EventMetadata( + room_id=room_id, event_type=event_type, state_key=state_key + ) + for event_id, room_id, event_type, state_key in txn + } + + result_map: Dict[str, EventMetadata] = {} + for batch_ids in batch_iter(event_ids, 1000): + result_map.update( + await self.db_pool.runInteraction( + "get_metadata_for_events", + get_metadata_for_events_txn, + batch_ids=batch_ids, + ) + ) + + return result_map + async def get_room_predecessor(self, room_id: str) -> Optional[JsonMapping]: """Get the predecessor of an upgraded room if it exists. Otherwise return None. diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 028db69af3..2282242e9d 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -729,49 +729,6 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): users.update(rows) return list(users) - async def get_mutual_rooms_for_users( - self, user_id: str, other_user_id: str - ) -> Set[str]: - """ - Returns the rooms that a local user shares with another local or remote user. - - Args: - user_id: The MXID of a local user - other_user_id: The MXID of the other user - - Returns: - A set of room ID's that the users share. - """ - - def _get_mutual_rooms_for_users_txn( - txn: LoggingTransaction, - ) -> List[Dict[str, str]]: - txn.execute( - """ - SELECT p1.room_id - FROM users_in_public_rooms as p1 - INNER JOIN users_in_public_rooms as p2 - ON p1.room_id = p2.room_id - AND p1.user_id = ? - AND p2.user_id = ? - UNION - SELECT room_id - FROM users_who_share_private_rooms - WHERE - user_id = ? - AND other_user_id = ? - """, - (user_id, other_user_id, user_id, other_user_id), - ) - rows = self.db_pool.cursor_to_dict(txn) - return rows - - rows = await self.db_pool.runInteraction( - "get_mutual_rooms_for_users", _get_mutual_rooms_for_users_txn - ) - - return {row["room_id"] for row in rows} - async def get_user_directory_stream_pos(self) -> Optional[int]: """ Get the stream ID of the user directory stream. diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index da98f05e03..19466150d4 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -SCHEMA_VERSION = 70 # remember to update the list below when updating +SCHEMA_VERSION = 71 # remember to update the list below when updating """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the @@ -67,6 +67,9 @@ Changes in SCHEMA_VERSION = 69: Changes in SCHEMA_VERSION = 70: - event_reference_hashes is no longer written to. + +Changes in SCHEMA_VERSION = 71: + - event_edges.room_id is no longer read from. """ diff --git a/synapse/storage/schema/main/delta/70/01clean_table_purged_rooms.sql b/synapse/storage/schema/main/delta/70/01clean_table_purged_rooms.sql new file mode 100644 index 0000000000..aed79635b2 --- /dev/null +++ b/synapse/storage/schema/main/delta/70/01clean_table_purged_rooms.sql @@ -0,0 +1,19 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Clean up left over rows from bug #11833, which was fixed in #12770. +DELETE FROM federation_inbound_events_staging WHERE room_id not in ( + SELECT room_id FROM rooms +); diff --git a/synapse/storage/state.py b/synapse/storage/state.py index ab630953ac..96aaffb53c 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -15,7 +15,6 @@ import logging from typing import ( TYPE_CHECKING, - Awaitable, Callable, Collection, Dict, @@ -32,15 +31,11 @@ import attr from frozendict import frozendict from synapse.api.constants import EventTypes -from synapse.events import EventBase -from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker from synapse.types import MutableStateMap, StateKey, StateMap if TYPE_CHECKING: from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad - from synapse.server import HomeServer - from synapse.storage.databases import Databases logger = logging.getLogger(__name__) @@ -578,318 +573,3 @@ _ALL_NON_MEMBER_STATE_FILTER = StateFilter( types=frozendict({EventTypes.Member: frozenset()}), include_others=True ) _NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False) - - -class StateGroupStorage: - """High level interface to fetching state for event.""" - - def __init__(self, hs: "HomeServer", stores: "Databases"): - self._is_mine_id = hs.is_mine_id - self.stores = stores - self._partial_state_events_tracker = PartialStateEventsTracker(stores.main) - - def notify_event_un_partial_stated(self, event_id: str) -> None: - self._partial_state_events_tracker.notify_un_partial_stated(event_id) - - async def get_state_group_delta( - self, state_group: int - ) -> Tuple[Optional[int], Optional[StateMap[str]]]: - """Given a state group try to return a previous group and a delta between - the old and the new. - - Args: - state_group: The state group used to retrieve state deltas. - - Returns: - A tuple of the previous group and a state map of the event IDs which - make up the delta between the old and new state groups. - """ - - state_group_delta = await self.stores.state.get_state_group_delta(state_group) - return state_group_delta.prev_group, state_group_delta.delta_ids - - async def get_state_groups_ids( - self, _room_id: str, event_ids: Collection[str] - ) -> Dict[int, MutableStateMap[str]]: - """Get the event IDs of all the state for the state groups for the given events - - Args: - _room_id: id of the room for these events - event_ids: ids of the events - - Returns: - dict of state_group_id -> (dict of (type, state_key) -> event id) - - Raises: - RuntimeError if we don't have a state group for one or more of the events - (ie they are outliers or unknown) - """ - if not event_ids: - return {} - - event_to_groups = await self.get_state_group_for_events(event_ids) - - groups = set(event_to_groups.values()) - group_to_state = await self.stores.state._get_state_for_groups(groups) - - return group_to_state - - async def get_state_ids_for_group( - self, state_group: int, state_filter: Optional[StateFilter] = None - ) -> StateMap[str]: - """Get the event IDs of all the state in the given state group - - Args: - state_group: A state group for which we want to get the state IDs. - state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules - - Returns: - Resolves to a map of (type, state_key) -> event_id - """ - group_to_state = await self.get_state_for_groups((state_group,), state_filter) - - return group_to_state[state_group] - - async def get_state_groups( - self, room_id: str, event_ids: Collection[str] - ) -> Dict[int, List[EventBase]]: - """Get the state groups for the given list of event_ids - - Args: - room_id: ID of the room for these events. - event_ids: The event IDs to retrieve state for. - - Returns: - dict of state_group_id -> list of state events. - """ - if not event_ids: - return {} - - group_to_ids = await self.get_state_groups_ids(room_id, event_ids) - - state_event_map = await self.stores.main.get_events( - [ - ev_id - for group_ids in group_to_ids.values() - for ev_id in group_ids.values() - ], - get_prev_content=False, - ) - - return { - group: [ - state_event_map[v] - for v in event_id_map.values() - if v in state_event_map - ] - for group, event_id_map in group_to_ids.items() - } - - def _get_state_groups_from_groups( - self, groups: List[int], state_filter: StateFilter - ) -> Awaitable[Dict[int, StateMap[str]]]: - """Returns the state groups for a given set of groups, filtering on - types of state events. - - Args: - groups: list of state group IDs to query - state_filter: The state filter used to fetch state - from the database. - - Returns: - Dict of state group to state map. - """ - - return self.stores.state._get_state_groups_from_groups(groups, state_filter) - - async def get_state_for_events( - self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None - ) -> Dict[str, StateMap[EventBase]]: - """Given a list of event_ids and type tuples, return a list of state - dicts for each event. - - Args: - event_ids: The events to fetch the state of. - state_filter: The state filter used to fetch state. - - Returns: - A dict of (event_id) -> (type, state_key) -> [state_events] - - Raises: - RuntimeError if we don't have a state group for one or more of the events - (ie they are outliers or unknown) - """ - await_full_state = True - if state_filter and not state_filter.must_await_full_state(self._is_mine_id): - await_full_state = False - - event_to_groups = await self.get_state_group_for_events( - event_ids, await_full_state=await_full_state - ) - - groups = set(event_to_groups.values()) - group_to_state = await self.stores.state._get_state_for_groups( - groups, state_filter or StateFilter.all() - ) - - state_event_map = await self.stores.main.get_events( - [ev_id for sd in group_to_state.values() for ev_id in sd.values()], - get_prev_content=False, - ) - - event_to_state = { - event_id: { - k: state_event_map[v] - for k, v in group_to_state[group].items() - if v in state_event_map - } - for event_id, group in event_to_groups.items() - } - - return {event: event_to_state[event] for event in event_ids} - - async def get_state_ids_for_events( - self, - event_ids: Collection[str], - state_filter: Optional[StateFilter] = None, - ) -> Dict[str, StateMap[str]]: - """ - Get the state dicts corresponding to a list of events, containing the event_ids - of the state events (as opposed to the events themselves) - - Args: - event_ids: events whose state should be returned - state_filter: The state filter used to fetch state from the database. - - Returns: - A dict from event_id -> (type, state_key) -> event_id - - Raises: - RuntimeError if we don't have a state group for one or more of the events - (ie they are outliers or unknown) - """ - await_full_state = True - if state_filter and not state_filter.must_await_full_state(self._is_mine_id): - await_full_state = False - - event_to_groups = await self.get_state_group_for_events( - event_ids, await_full_state=await_full_state - ) - - groups = set(event_to_groups.values()) - group_to_state = await self.stores.state._get_state_for_groups( - groups, state_filter or StateFilter.all() - ) - - event_to_state = { - event_id: group_to_state[group] - for event_id, group in event_to_groups.items() - } - - return {event: event_to_state[event] for event in event_ids} - - async def get_state_for_event( - self, event_id: str, state_filter: Optional[StateFilter] = None - ) -> StateMap[EventBase]: - """ - Get the state dict corresponding to a particular event - - Args: - event_id: event whose state should be returned - state_filter: The state filter used to fetch state from the database. - - Returns: - A dict from (type, state_key) -> state_event - - Raises: - RuntimeError if we don't have a state group for the event (ie it is an - outlier or is unknown) - """ - state_map = await self.get_state_for_events( - [event_id], state_filter or StateFilter.all() - ) - return state_map[event_id] - - async def get_state_ids_for_event( - self, event_id: str, state_filter: Optional[StateFilter] = None - ) -> StateMap[str]: - """ - Get the state dict corresponding to a particular event - - Args: - event_id: event whose state should be returned - state_filter: The state filter used to fetch state from the database. - - Returns: - A dict from (type, state_key) -> state_event_id - - Raises: - RuntimeError if we don't have a state group for the event (ie it is an - outlier or is unknown) - """ - state_map = await self.get_state_ids_for_events( - [event_id], state_filter or StateFilter.all() - ) - return state_map[event_id] - - def get_state_for_groups( - self, groups: Iterable[int], state_filter: Optional[StateFilter] = None - ) -> Awaitable[Dict[int, MutableStateMap[str]]]: - """Gets the state at each of a list of state groups, optionally - filtering by type/state_key - - Args: - groups: list of state groups for which we want to get the state. - state_filter: The state filter used to fetch state. - from the database. - - Returns: - Dict of state group to state map. - """ - return self.stores.state._get_state_for_groups( - groups, state_filter or StateFilter.all() - ) - - async def get_state_group_for_events( - self, - event_ids: Collection[str], - await_full_state: bool = True, - ) -> Mapping[str, int]: - """Returns mapping event_id -> state_group - - Args: - event_ids: events to get state groups for - await_full_state: if true, will block if we do not yet have complete - state at these events. - """ - if await_full_state: - await self._partial_state_events_tracker.await_full_state(event_ids) - - return await self.stores.main._get_state_group_for_events(event_ids) - - async def store_state_group( - self, - event_id: str, - room_id: str, - prev_group: Optional[int], - delta_ids: Optional[StateMap[str]], - current_state_ids: StateMap[str], - ) -> int: - """Store a new set of state, returning a newly assigned state group. - - Args: - event_id: The event ID for which the state was calculated. - room_id: ID of the room for which the state was calculated. - prev_group: A previous state group for the room, optional. - delta_ids: The delta between state at `prev_group` and - `current_state_ids`, if `prev_group` was given. Same format as - `current_state_ids`. - current_state_ids: The state to store. Map of (type, state_key) - to event_id. - - Returns: - The state group ID - """ - return await self.stores.state.store_state_group( - event_id, room_id, prev_group, delta_ids, current_state_ids - ) diff --git a/synapse/streams/events.py b/synapse/streams/events.py index acf17ba623..54e0b1a23b 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -54,7 +54,6 @@ class EventSources: push_rules_key = self.store.get_max_push_rules_stream_id() to_device_key = self.store.get_to_device_stream_token() device_list_key = self.store.get_device_stream_token() - groups_key = self.store.get_group_stream_token() token = StreamToken( room_key=self.sources.room.get_current_key(), @@ -65,7 +64,8 @@ class EventSources: push_rules_key=push_rules_key, to_device_key=to_device_key, device_list_key=device_list_key, - groups_key=groups_key, + # Groups key is unused. + groups_key=0, ) return token diff --git a/synapse/types.py b/synapse/types.py index 6f7128ddd6..0586d2cbb9 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -320,29 +320,6 @@ class EventID(DomainSpecificString): SIGIL = "$" -@attr.s(slots=True, frozen=True, repr=False) -class GroupID(DomainSpecificString): - """Structure representing a group ID.""" - - SIGIL = "+" - - @classmethod - def from_string(cls: Type[DS], s: str) -> DS: - group_id: DS = super().from_string(s) # type: ignore - - if not group_id.localpart: - raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM) - - if contains_invalid_mxid_characters(group_id.localpart): - raise SynapseError( - 400, - "Group ID can only contain characters a-z, 0-9, or '=_-./'", - Codes.INVALID_PARAM, - ) - - return group_id - - mxid_localpart_allowed_characters = set( "_-./=" + string.ascii_lowercase + string.digits ) @@ -662,7 +639,7 @@ class StreamToken: 6. `push_rules_key`: `541479` 7. `to_device_key`: `274711` 8. `device_list_key`: `265584` - 9. `groups_key`: `1` + 9. `groups_key`: `1` (note that this key is now unused) You can see how many of these keys correspond to the various fields in a "/sync" response: @@ -714,6 +691,7 @@ class StreamToken: push_rules_key: int to_device_key: int device_list_key: int + # Note that the groups key is no longer used and may have bogus values. groups_key: int _SEPARATOR = "_" @@ -745,6 +723,9 @@ class StreamToken: str(self.push_rules_key), str(self.to_device_key), str(self.device_list_key), + # Note that the groups key is no longer used, but it is still + # serialized so that there will not be confusion in the future + # if additional tokens are added. str(self.groups_key), ] ) diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index eda92d864d..867f315b2a 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -595,13 +595,14 @@ def cached( def cachedList( *, cached_method_name: str, list_name: str, num_args: Optional[int] = None ) -> Callable[[F], _CachedFunction[F]]: - """Creates a descriptor that wraps a function in a `CacheListDescriptor`. + """Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`. - Used to do batch lookups for an already created cache. A single argument + Used to do batch lookups for an already created cache. One of the arguments is specified as a list that is iterated through to lookup keys in the original cache. A new tuple consisting of the (deduplicated) keys that weren't in - the cache gets passed to the original function, the result of which is stored in the - cache. + the cache gets passed to the original function, which is expected to results + in a map of key to value for each passed value. THe new results are stored in the + original cache. Note that any missing values are cached as None. Args: cached_method_name: The name of the single-item lookup method. @@ -614,11 +615,11 @@ def cachedList( Example: class Example: - @cached(num_args=2) - def do_something(self, first_arg): + @cached() + def do_something(self, first_arg, second_arg): ... - @cachedList(do_something.cache, list_name="second_args", num_args=2) + @cachedList(cached_method_name="do_something", list_name="second_args") def batch_do_something(self, first_arg, second_args): ... """ diff --git a/synapse/visibility.py b/synapse/visibility.py index da4af02796..97548c14e3 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -20,7 +20,7 @@ from typing_extensions import Final from synapse.api.constants import EventTypes, HistoryVisibility, Membership from synapse.events import EventBase from synapse.events.utils import prune_event -from synapse.storage import Storage +from synapse.storage.controllers import StorageControllers from synapse.storage.state import StateFilter from synapse.types import RetentionPolicy, StateMap, get_domain_from_id @@ -47,7 +47,7 @@ _HISTORY_VIS_KEY: Final[Tuple[str, str]] = (EventTypes.RoomHistoryVisibility, "" async def filter_events_for_client( - storage: Storage, + storage: StorageControllers, user_id: str, events: List[EventBase], is_peeking: bool = False, @@ -268,7 +268,7 @@ async def filter_events_for_client( async def filter_events_for_server( - storage: Storage, + storage: StorageControllers, server_name: str, events: List[EventBase], redact: bool = True, @@ -360,7 +360,7 @@ async def filter_events_for_server( async def _event_to_history_vis( - storage: Storage, events: Collection[EventBase] + storage: StorageControllers, events: Collection[EventBase] ) -> Dict[str, str]: """Get the history visibility at each of the given events @@ -407,7 +407,7 @@ async def _event_to_history_vis( async def _event_to_memberships( - storage: Storage, events: Collection[EventBase], server_name: str + storage: StorageControllers, events: Collection[EventBase], server_name: str ) -> Dict[str, StateMap[EventBase]]: """Get the remote membership list at each of the given events |