summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorPatrick Cloke <patrickc@matrix.org>2023-10-02 09:08:23 -0400
committerPatrick Cloke <patrickc@matrix.org>2023-10-02 09:08:23 -0400
commit656ffa23c9ecbe6612aa850e4fc584d799984af2 (patch)
tree3414e729d9a2ffa042cec0d848fdc1361a35d8c8 /synapse
parentRevert "Temporarily disable webp thumbnailing" (diff)
parentRemove Python version from `/_synapse/admin/v1/server_version` (#16380) (diff)
downloadsynapse-656ffa23c9ecbe6612aa850e4fc584d799984af2.tar.xz
Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
Diffstat (limited to 'synapse')
-rw-r--r--synapse/_pydantic_compat.py26
-rw-r--r--synapse/api/filtering.py8
-rw-r--r--synapse/appservice/__init__.py6
-rw-r--r--synapse/appservice/api.py6
-rw-r--r--synapse/appservice/scheduler.py18
-rw-r--r--synapse/config/_util.py10
-rw-r--r--synapse/config/experimental.py4
-rw-r--r--synapse/config/workers.py10
-rw-r--r--synapse/events/validator.py17
-rw-r--r--synapse/federation/federation_client.py4
-rw-r--r--synapse/federation/federation_server.py76
-rw-r--r--synapse/handlers/appservice.py9
-rw-r--r--synapse/handlers/e2e_keys.py24
-rw-r--r--synapse/handlers/federation_event.py8
-rw-r--r--synapse/handlers/initial_sync.py3
-rw-r--r--synapse/handlers/message.py5
-rw-r--r--synapse/handlers/presence.py33
-rw-r--r--synapse/handlers/receipts.py13
-rw-r--r--synapse/handlers/relations.py14
-rw-r--r--synapse/handlers/sync.py4
-rw-r--r--synapse/handlers/typing.py17
-rw-r--r--synapse/http/servlet.py11
-rw-r--r--synapse/media/_base.py42
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py2
-rw-r--r--synapse/replication/tcp/client.py8
-rw-r--r--synapse/rest/admin/__init__.py6
-rw-r--r--synapse/rest/client/account.py7
-rw-r--r--synapse/rest/client/devices.py7
-rw-r--r--synapse/rest/client/directory.py8
-rw-r--r--synapse/rest/client/filter.py4
-rw-r--r--synapse/rest/client/models.py7
-rw-r--r--synapse/rest/key/v2/remote_key_resource.py8
-rw-r--r--synapse/rest/models.py22
-rw-r--r--synapse/state/v2.py5
-rw-r--r--synapse/storage/background_updates.py7
-rw-r--r--synapse/storage/controllers/state.py61
-rw-r--r--synapse/storage/database.py14
-rw-r--r--synapse/storage/databases/main/appservice.py6
-rw-r--r--synapse/storage/databases/main/devices.py23
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py25
-rw-r--r--synapse/storage/databases/main/events_worker.py5
-rw-r--r--synapse/storage/databases/main/filtering.py4
-rw-r--r--synapse/storage/databases/main/keys.py6
-rw-r--r--synapse/storage/databases/main/presence.py14
-rw-r--r--synapse/storage/databases/main/push_rule.py3
-rw-r--r--synapse/storage/databases/main/receipts.py14
-rw-r--r--synapse/storage/databases/main/relations.py10
-rw-r--r--synapse/storage/databases/main/roommember.py18
-rw-r--r--synapse/storage/databases/main/state.py14
-rw-r--r--synapse/storage/databases/main/transactions.py4
-rw-r--r--synapse/storage/databases/main/user_erasure_store.py4
-rw-r--r--synapse/storage/databases/state/bg_updates.py18
-rw-r--r--synapse/storage/types.py20
53 files changed, 458 insertions, 264 deletions
diff --git a/synapse/_pydantic_compat.py b/synapse/_pydantic_compat.py
new file mode 100644

index 0000000000..ddff72afa1 --- /dev/null +++ b/synapse/_pydantic_compat.py
@@ -0,0 +1,26 @@ +# Copyright 2023 Maxwell G <maxwell@gtmx.me> +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from packaging.version import Version + +try: + from pydantic import __version__ as pydantic_version +except ImportError: + import importlib.metadata + + pydantic_version = importlib.metadata.version("pydantic") + +HAS_PYDANTIC_V2: bool = Version(pydantic_version).major == 2 + +__all__ = ("HAS_PYDANTIC_V2",) diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 0995ecbe83..74ee8e9f3f 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py
@@ -37,7 +37,7 @@ from synapse.api.constants import EduTypes, EventContentFields from synapse.api.errors import SynapseError from synapse.api.presence import UserPresenceState from synapse.events import EventBase, relation_from_event -from synapse.types import JsonDict, RoomID, UserID +from synapse.types import JsonDict, JsonMapping, RoomID, UserID if TYPE_CHECKING: from synapse.server import HomeServer @@ -191,7 +191,7 @@ FilterEvent = TypeVar("FilterEvent", EventBase, UserPresenceState, JsonDict) class FilterCollection: - def __init__(self, hs: "HomeServer", filter_json: JsonDict): + def __init__(self, hs: "HomeServer", filter_json: JsonMapping): self._filter_json = filter_json room_filter_json = self._filter_json.get("room", {}) @@ -219,7 +219,7 @@ class FilterCollection: def __repr__(self) -> str: return "<FilterCollection %s>" % (json.dumps(self._filter_json),) - def get_filter_json(self) -> JsonDict: + def get_filter_json(self) -> JsonMapping: return self._filter_json def timeline_limit(self) -> int: @@ -313,7 +313,7 @@ class FilterCollection: class Filter: - def __init__(self, hs: "HomeServer", filter_json: JsonDict): + def __init__(self, hs: "HomeServer", filter_json: JsonMapping): self._hs = hs self._store = hs.get_datastores().main self.filter_json = filter_json diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index 2260a8f589..6f4aa53c93 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py
@@ -23,7 +23,7 @@ from netaddr import IPSet from synapse.api.constants import EventTypes from synapse.events import EventBase -from synapse.types import DeviceListUpdates, JsonDict, UserID +from synapse.types import DeviceListUpdates, JsonDict, JsonMapping, UserID from synapse.util.caches.descriptors import _CacheContext, cached if TYPE_CHECKING: @@ -379,8 +379,8 @@ class AppServiceTransaction: service: ApplicationService, id: int, events: Sequence[EventBase], - ephemeral: List[JsonDict], - to_device_messages: List[JsonDict], + ephemeral: List[JsonMapping], + to_device_messages: List[JsonMapping], one_time_keys_count: TransactionOneTimeKeysCount, unused_fallback_keys: TransactionUnusedFallbackKeys, device_list_summary: DeviceListUpdates, diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index b1523be208..c42e1f11aa 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py
@@ -41,7 +41,7 @@ from synapse.events import EventBase from synapse.events.utils import SerializeEventConfig, serialize_event from synapse.http.client import SimpleHttpClient, is_unknown_endpoint from synapse.logging import opentracing -from synapse.types import DeviceListUpdates, JsonDict, ThirdPartyInstanceID +from synapse.types import DeviceListUpdates, JsonDict, JsonMapping, ThirdPartyInstanceID from synapse.util.caches.response_cache import ResponseCache if TYPE_CHECKING: @@ -306,8 +306,8 @@ class ApplicationServiceApi(SimpleHttpClient): self, service: "ApplicationService", events: Sequence[EventBase], - ephemeral: List[JsonDict], - to_device_messages: List[JsonDict], + ephemeral: List[JsonMapping], + to_device_messages: List[JsonMapping], one_time_keys_count: TransactionOneTimeKeysCount, unused_fallback_keys: TransactionUnusedFallbackKeys, device_list_summary: DeviceListUpdates, diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index 79f95f7653..18a30bc376 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py
@@ -73,7 +73,7 @@ from synapse.events import EventBase from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.databases.main import DataStore -from synapse.types import DeviceListUpdates, JsonDict +from synapse.types import DeviceListUpdates, JsonMapping from synapse.util import Clock if TYPE_CHECKING: @@ -121,8 +121,8 @@ class ApplicationServiceScheduler: self, appservice: ApplicationService, events: Optional[Collection[EventBase]] = None, - ephemeral: Optional[Collection[JsonDict]] = None, - to_device_messages: Optional[Collection[JsonDict]] = None, + ephemeral: Optional[Collection[JsonMapping]] = None, + to_device_messages: Optional[Collection[JsonMapping]] = None, device_list_summary: Optional[DeviceListUpdates] = None, ) -> None: """ @@ -180,9 +180,9 @@ class _ServiceQueuer: # dict of {service_id: [events]} self.queued_events: Dict[str, List[EventBase]] = {} # dict of {service_id: [events]} - self.queued_ephemeral: Dict[str, List[JsonDict]] = {} + self.queued_ephemeral: Dict[str, List[JsonMapping]] = {} # dict of {service_id: [to_device_message_json]} - self.queued_to_device_messages: Dict[str, List[JsonDict]] = {} + self.queued_to_device_messages: Dict[str, List[JsonMapping]] = {} # dict of {service_id: [device_list_summary]} self.queued_device_list_summaries: Dict[str, List[DeviceListUpdates]] = {} @@ -293,8 +293,8 @@ class _ServiceQueuer: self, service: ApplicationService, events: Iterable[EventBase], - ephemerals: Iterable[JsonDict], - to_device_messages: Iterable[JsonDict], + ephemerals: Iterable[JsonMapping], + to_device_messages: Iterable[JsonMapping], ) -> Tuple[TransactionOneTimeKeysCount, TransactionUnusedFallbackKeys]: """ Given a list of the events, ephemeral messages and to-device messages, @@ -364,8 +364,8 @@ class _TransactionController: self, service: ApplicationService, events: Sequence[EventBase], - ephemeral: Optional[List[JsonDict]] = None, - to_device_messages: Optional[List[JsonDict]] = None, + ephemeral: Optional[List[JsonMapping]] = None, + to_device_messages: Optional[List[JsonMapping]] = None, one_time_keys_count: Optional[TransactionOneTimeKeysCount] = None, unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None, device_list_summary: Optional[DeviceListUpdates] = None, diff --git a/synapse/config/_util.py b/synapse/config/_util.py
index acccca413b..746838eee3 100644 --- a/synapse/config/_util.py +++ b/synapse/config/_util.py
@@ -11,10 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Type, TypeVar +from typing import TYPE_CHECKING, Any, Dict, Type, TypeVar import jsonschema -from pydantic import BaseModel, ValidationError, parse_obj_as + +from synapse._pydantic_compat import HAS_PYDANTIC_V2 + +if TYPE_CHECKING or HAS_PYDANTIC_V2: + from pydantic.v1 import BaseModel, ValidationError, parse_obj_as +else: + from pydantic import BaseModel, ValidationError, parse_obj_as from synapse.config._base import ConfigError from synapse.types import JsonDict, StrSequence diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index cabe0d4397..9f830e7094 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py
@@ -415,3 +415,7 @@ class ExperimentalConfig(Config): LimitExceededError.include_retry_after_header = experimental.get( "msc4041_enabled", False ) + + self.msc4028_push_encrypted_events = experimental.get( + "msc4028_push_encrypted_events", False + ) diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index 6567fb6bb0..f1766088fc 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py
@@ -15,10 +15,16 @@ import argparse import logging -from typing import Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import attr -from pydantic import BaseModel, Extra, StrictBool, StrictInt, StrictStr + +from synapse._pydantic_compat import HAS_PYDANTIC_V2 + +if TYPE_CHECKING or HAS_PYDANTIC_V2: + from pydantic.v1 import BaseModel, Extra, StrictBool, StrictInt, StrictStr +else: + from pydantic import BaseModel, Extra, StrictBool, StrictInt, StrictStr from synapse.config._base import ( Config, diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index 5da50cb0d2..83d9fb5813 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py
@@ -12,10 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. import collections.abc -from typing import List, Type, Union, cast +from typing import TYPE_CHECKING, List, Type, Union, cast import jsonschema -from pydantic import Field, StrictBool, StrictStr + +from synapse._pydantic_compat import HAS_PYDANTIC_V2 + +if TYPE_CHECKING or HAS_PYDANTIC_V2: + from pydantic.v1 import Field, StrictBool, StrictStr +else: + from pydantic import Field, StrictBool, StrictStr from synapse.api.constants import ( MAX_ALIAS_LENGTH, @@ -33,9 +39,9 @@ from synapse.events.utils import ( CANONICALJSON_MIN_INT, validate_canonicaljson, ) -from synapse.federation.federation_server import server_matches_acl_event from synapse.http.servlet import validate_json_object from synapse.rest.models import RequestBodyModel +from synapse.storage.controllers.state import server_acl_evaluator_from_event from synapse.types import EventID, JsonDict, RoomID, StrCollection, UserID @@ -100,7 +106,10 @@ class EventValidator: self._validate_retention(event) elif event.type == EventTypes.ServerACL: - if not server_matches_acl_event(config.server.server_name, event): + server_acl_evaluator = server_acl_evaluator_from_event(event) + if not server_acl_evaluator.server_matches_acl_event( + config.server.server_name + ): raise SynapseError( 400, "Can't create an ACL event that denies the local server" ) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 607013f121..c8bc46415d 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py
@@ -64,7 +64,7 @@ from synapse.federation.transport.client import SendJoinResponse from synapse.http.client import is_unknown_endpoint from synapse.http.types import QueryParams from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, tag_args, trace -from synapse.types import JsonDict, UserID, get_domain_from_id +from synapse.types import JsonDict, StrCollection, UserID, get_domain_from_id from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination @@ -1704,7 +1704,7 @@ class FederationClient(FederationBase): async def timestamp_to_event( self, *, - destinations: List[str], + destinations: StrCollection, room_id: str, timestamp: int, direction: Direction, diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index f9915e5a3f..ec8e770430 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py
@@ -29,10 +29,8 @@ from typing import ( Union, ) -from matrix_common.regex import glob_to_regex from prometheus_client import Counter, Gauge, Histogram -from twisted.internet.abstract import isIPAddress from twisted.python import failure from synapse.api.constants import ( @@ -1324,75 +1322,13 @@ class FederationServer(FederationBase): Raises: AuthError if the server does not match the ACL """ - acl_event = await self._storage_controllers.state.get_current_state_event( - room_id, EventTypes.ServerACL, "" + server_acl_evaluator = ( + await self._storage_controllers.state.get_server_acl_for_room(room_id) ) - if not acl_event or server_matches_acl_event(server_name, acl_event): - return - - raise AuthError(code=403, msg="Server is banned from room") - - -def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool: - """Check if the given server is allowed by the ACL event - - Args: - server_name: name of server, without any port part - acl_event: m.room.server_acl event - - Returns: - True if this server is allowed by the ACLs - """ - logger.debug("Checking %s against acl %s", server_name, acl_event.content) - - # first of all, check if literal IPs are blocked, and if so, whether the - # server name is a literal IP - allow_ip_literals = acl_event.content.get("allow_ip_literals", True) - if not isinstance(allow_ip_literals, bool): - logger.warning("Ignoring non-bool allow_ip_literals flag") - allow_ip_literals = True - if not allow_ip_literals: - # check for ipv6 literals. These start with '['. - if server_name[0] == "[": - return False - - # check for ipv4 literals. We can just lift the routine from twisted. - if isIPAddress(server_name): - return False - - # next, check the deny list - deny = acl_event.content.get("deny", []) - if not isinstance(deny, (list, tuple)): - logger.warning("Ignoring non-list deny ACL %s", deny) - deny = [] - for e in deny: - if _acl_entry_matches(server_name, e): - # logger.info("%s matched deny rule %s", server_name, e) - return False - - # then the allow list. - allow = acl_event.content.get("allow", []) - if not isinstance(allow, (list, tuple)): - logger.warning("Ignoring non-list allow ACL %s", allow) - allow = [] - for e in allow: - if _acl_entry_matches(server_name, e): - # logger.info("%s matched allow rule %s", server_name, e) - return True - - # everything else should be rejected. - # logger.info("%s fell through", server_name) - return False - - -def _acl_entry_matches(server_name: str, acl_entry: Any) -> bool: - if not isinstance(acl_entry, str): - logger.warning( - "Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry) - ) - return False - regex = glob_to_regex(acl_entry) - return bool(regex.match(server_name)) + if server_acl_evaluator and not server_acl_evaluator.server_matches_acl_event( + server_name + ): + raise AuthError(code=403, msg="Server is banned from room") class FederationHandlerRegistry: diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 6429545c98..7de7bd3289 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py
@@ -46,6 +46,7 @@ from synapse.storage.databases.main.directory import RoomAliasMapping from synapse.types import ( DeviceListUpdates, JsonDict, + JsonMapping, RoomAlias, RoomStreamToken, StreamKeyType, @@ -397,7 +398,7 @@ class ApplicationServicesHandler: async def _handle_typing( self, service: ApplicationService, new_token: int - ) -> List[JsonDict]: + ) -> List[JsonMapping]: """ Return the typing events since the given stream token that the given application service should receive. @@ -432,7 +433,7 @@ class ApplicationServicesHandler: async def _handle_receipts( self, service: ApplicationService, new_token: int - ) -> List[JsonDict]: + ) -> List[JsonMapping]: """ Return the latest read receipts that the given application service should receive. @@ -471,7 +472,7 @@ class ApplicationServicesHandler: service: ApplicationService, users: Collection[Union[str, UserID]], new_token: Optional[int], - ) -> List[JsonDict]: + ) -> List[JsonMapping]: """ Return the latest presence updates that the given application service should receive. @@ -491,7 +492,7 @@ class ApplicationServicesHandler: A list of json dictionaries containing data derived from the presence events that should be sent to the given application service. """ - events: List[JsonDict] = [] + events: List[JsonMapping] = [] presence_source = self.event_sources.sources.presence from_key = await self.store.get_type_stream_id_for_appservice( service, "presence" diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index ad075497c8..8c6432035d 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py
@@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Tuple import attr from canonicaljson import encode_canonical_json @@ -31,6 +31,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace from synapse.types import ( JsonDict, + JsonMapping, UserID, get_domain_from_id, get_verify_key_from_cross_signing_key, @@ -272,11 +273,7 @@ class E2eKeysHandler: delay_cancellation=True, ) - ret = {"device_keys": results, "failures": failures} - - ret.update(cross_signing_keys) - - return ret + return {"device_keys": results, "failures": failures, **cross_signing_keys} @trace async def _query_devices_for_destination( @@ -408,7 +405,7 @@ class E2eKeysHandler: @cancellable async def get_cross_signing_keys_from_cache( self, query: Iterable[str], from_user_id: Optional[str] - ) -> Dict[str, Dict[str, dict]]: + ) -> Dict[str, Dict[str, JsonMapping]]: """Get cross-signing keys for users from the database Args: @@ -551,16 +548,13 @@ class E2eKeysHandler: self.config.federation.allow_device_name_lookup_over_federation ), ) - ret = {"device_keys": res} # add in the cross-signing keys cross_signing_keys = await self.get_cross_signing_keys_from_cache( device_keys_query, None ) - ret.update(cross_signing_keys) - - return ret + return {"device_keys": res, **cross_signing_keys} async def claim_local_one_time_keys( self, @@ -1127,7 +1121,7 @@ class E2eKeysHandler: user_id: str, master_key_id: str, signed_master_key: JsonDict, - stored_master_key: JsonDict, + stored_master_key: JsonMapping, devices: Dict[str, Dict[str, JsonDict]], ) -> List["SignatureListItem"]: """Check signatures of a user's master key made by their devices. @@ -1278,7 +1272,7 @@ class E2eKeysHandler: async def _get_e2e_cross_signing_verify_key( self, user_id: str, key_type: str, from_user_id: Optional[str] = None - ) -> Tuple[JsonDict, str, VerifyKey]: + ) -> Tuple[JsonMapping, str, VerifyKey]: """Fetch locally or remotely query for a cross-signing public key. First, attempt to fetch the cross-signing public key from storage. @@ -1333,7 +1327,7 @@ class E2eKeysHandler: self, user: UserID, desired_key_type: str, - ) -> Optional[Tuple[Dict[str, Any], str, VerifyKey]]: + ) -> Optional[Tuple[JsonMapping, str, VerifyKey]]: """Queries cross-signing keys for a remote user and saves them to the database Only the key specified by `key_type` will be returned, while all retrieved keys @@ -1474,7 +1468,7 @@ def _check_device_signature( user_id: str, verify_key: VerifyKey, signed_device: JsonDict, - stored_device: JsonDict, + stored_device: JsonMapping, ) -> None: """Check that a signature on a device or cross-signing key is correct and matches the copy of the device/key that we have stored. Throws an diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index eedde97ab0..0cc8e990d9 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py
@@ -1538,7 +1538,7 @@ class FederationEventHandler: logger.exception("Failed to resync device for %s", sender) async def backfill_event_id( - self, destinations: List[str], room_id: str, event_id: str + self, destinations: StrCollection, room_id: str, event_id: str ) -> PulledPduInfo: """Backfill a single event and persist it as a non-outlier which means we also pull in all of the state and auth events necessary for it. @@ -2342,6 +2342,12 @@ class FederationEventHandler: # TODO retrieve the previous state, and exclude join -> join transitions self._notifier.notify_user_joined_room(event.event_id, event.room_id) + # If this is a server ACL event, clear the cache in the storage controller. + if event.type == EventTypes.ServerACL: + self._state_storage_controller.get_server_acl_for_room.invalidate( + (event.room_id,) + ) + def _sanity_check_event(self, ev: EventBase) -> None: """ Do some early sanity checks of a received event diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 5dc76ef588..5737f8014d 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py
@@ -32,6 +32,7 @@ from synapse.storage.roommember import RoomsForUser from synapse.streams.config import PaginationConfig from synapse.types import ( JsonDict, + JsonMapping, Requester, RoomStreamToken, StreamKeyType, @@ -454,7 +455,7 @@ class InitialSyncHandler: for s in states ] - async def get_receipts() -> List[JsonDict]: + async def get_receipts() -> List[JsonMapping]: receipts = await self.store.get_linearized_receipts_for_room( room_id, to_key=now_token.receipt_key ) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index c036578a3d..44dbbf81dd 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py
@@ -1730,6 +1730,11 @@ class EventCreationHandler: event.event_id, event.room_id ) + if event.type == EventTypes.ServerACL: + self._storage_controllers.state.get_server_acl_for_room.invalidate( + (event.room_id,) + ) + await self._maybe_kick_guest_users(event, context) if event.type == EventTypes.CanonicalAlias: diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 375c7d0901..7c7cda3e95 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py
@@ -401,9 +401,9 @@ class BasePresenceHandler(abc.ABC): states, ) - for destination, host_states in hosts_to_states.items(): + for destinations, host_states in hosts_to_states: await self._federation.send_presence_to_destinations( - host_states, [destination] + host_states, destinations ) async def send_full_presence_to_users(self, user_ids: StrCollection) -> None: @@ -1000,9 +1000,9 @@ class PresenceHandler(BasePresenceHandler): list(to_federation_ping.values()), ) - for destination, states in hosts_to_states.items(): + for destinations, states in hosts_to_states: await self._federation_queue.send_presence_to_destinations( - states, [destination] + states, destinations ) @wrap_as_background_process("handle_presence_timeouts") @@ -2276,7 +2276,7 @@ async def get_interested_remotes( store: DataStore, presence_router: PresenceRouter, states: List[UserPresenceState], -) -> Dict[str, Set[UserPresenceState]]: +) -> List[Tuple[StrCollection, Collection[UserPresenceState]]]: """Given a list of presence states figure out which remote servers should be sent which. @@ -2290,23 +2290,26 @@ async def get_interested_remotes( Returns: A map from destinations to presence states to send to that destination. """ - hosts_and_states: Dict[str, Set[UserPresenceState]] = {} + hosts_and_states: List[Tuple[StrCollection, Collection[UserPresenceState]]] = [] # First we look up the rooms each user is in (as well as any explicit # subscriptions), then for each distinct room we look up the remote # hosts in those rooms. - room_ids_to_states, users_to_states = await get_interested_parties( - store, presence_router, states - ) + for state in states: + room_ids = await store.get_rooms_for_user(state.user_id) + hosts: Set[str] = set() + for room_id in room_ids: + room_hosts = await store.get_current_hosts_in_room(room_id) + hosts.update(room_hosts) + hosts_and_states.append((hosts, [state])) - for room_id, states in room_ids_to_states.items(): - hosts = await store.get_current_hosts_in_room(room_id) - for host in hosts: - hosts_and_states.setdefault(host, set()).update(states) + # Ask a presence routing module for any additional parties if one + # is loaded. + router_users_to_states = await presence_router.get_users_for_states(states) - for user_id, states in users_to_states.items(): + for user_id, user_states in router_users_to_states.items(): host = get_domain_from_id(user_id) - hosts_and_states.setdefault(host, set()).update(states) + hosts_and_states.append(([host], user_states)) return hosts_and_states diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index c7edada353..a7a29b758b 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py
@@ -19,6 +19,7 @@ from synapse.appservice import ApplicationService from synapse.streams import EventSource from synapse.types import ( JsonDict, + JsonMapping, ReadReceipt, StreamKeyType, UserID, @@ -204,15 +205,15 @@ class ReceiptsHandler: await self.federation_sender.send_read_receipt(receipt) -class ReceiptEventSource(EventSource[int, JsonDict]): +class ReceiptEventSource(EventSource[int, JsonMapping]): def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self.config = hs.config @staticmethod def filter_out_private_receipts( - rooms: Sequence[JsonDict], user_id: str - ) -> List[JsonDict]: + rooms: Sequence[JsonMapping], user_id: str + ) -> List[JsonMapping]: """ Filters a list of serialized receipts (as returned by /sync and /initialSync) and removes private read receipts of other users. @@ -229,7 +230,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]): The same as rooms, but filtered. """ - result = [] + result: List[JsonMapping] = [] # Iterate through each room's receipt content. for room in rooms: @@ -282,7 +283,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]): room_ids: Iterable[str], is_guest: bool, explicit_room_id: Optional[str] = None, - ) -> Tuple[List[JsonDict], int]: + ) -> Tuple[List[JsonMapping], int]: from_key = int(from_key) to_key = self.get_current_key() @@ -301,7 +302,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]): async def get_new_events_as( self, from_key: int, to_key: int, service: ApplicationService - ) -> Tuple[List[JsonDict], int]: + ) -> Tuple[List[JsonMapping], int]: """Returns a set of new read receipt events that an appservice may be interested in. diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index db97f7aede..9b13448cdd 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py
@@ -13,7 +13,17 @@ # limitations under the License. import enum import logging -from typing import TYPE_CHECKING, Collection, Dict, FrozenSet, Iterable, List, Optional +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + FrozenSet, + Iterable, + List, + Mapping, + Optional, + Sequence, +) import attr @@ -245,7 +255,7 @@ class RelationsHandler: async def get_references_for_events( self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset() - ) -> Dict[str, List[_RelatedEvent]]: + ) -> Mapping[str, Sequence[_RelatedEvent]]: """Get a list of references to the given events. Args: diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 1a4d394eda..7bd42f635f 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py
@@ -235,7 +235,7 @@ class SyncResult: archived: List[ArchivedSyncResult] to_device: List[JsonDict] device_lists: DeviceListUpdates - device_one_time_keys_count: JsonDict + device_one_time_keys_count: JsonMapping device_unused_fallback_key_types: List[str] def __bool__(self) -> bool: @@ -1558,7 +1558,7 @@ class SyncHandler: logger.debug("Fetching OTK data") device_id = sync_config.device_id - one_time_keys_count: JsonDict = {} + one_time_keys_count: JsonMapping = {} unused_fallback_key_types: List[str] = [] if device_id: # TODO: We should have a way to let clients differentiate between the states of: diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 4b4227003d..bdefa7f26f 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py
@@ -26,7 +26,14 @@ from synapse.metrics.background_process_metrics import ( ) from synapse.replication.tcp.streams import TypingStream from synapse.streams import EventSource -from synapse.types import JsonDict, Requester, StrCollection, StreamKeyType, UserID +from synapse.types import ( + JsonDict, + JsonMapping, + Requester, + StrCollection, + StreamKeyType, + UserID, +) from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.metrics import Measure from synapse.util.retryutils import filter_destinations_by_retry_limiter @@ -487,7 +494,7 @@ class TypingWriterHandler(FollowerTypingHandler): raise Exception("Typing writer instance got typing info over replication") -class TypingNotificationEventSource(EventSource[int, JsonDict]): +class TypingNotificationEventSource(EventSource[int, JsonMapping]): def __init__(self, hs: "HomeServer"): self._main_store = hs.get_datastores().main self.clock = hs.get_clock() @@ -497,7 +504,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]): # self.get_typing_handler = hs.get_typing_handler - def _make_event_for(self, room_id: str) -> JsonDict: + def _make_event_for(self, room_id: str) -> JsonMapping: typing = self.get_typing_handler()._room_typing[room_id] return { "type": EduTypes.TYPING, @@ -507,7 +514,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]): async def get_new_events_as( self, from_key: int, service: ApplicationService - ) -> Tuple[List[JsonDict], int]: + ) -> Tuple[List[JsonMapping], int]: """Returns a set of new typing events that an appservice may be interested in. @@ -551,7 +558,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]): room_ids: Iterable[str], is_guest: bool, explicit_room_id: Optional[str] = None, - ) -> Tuple[List[JsonDict], int]: + ) -> Tuple[List[JsonMapping], int]: with Measure(self.clock, "typing.get_new_events"): from_key = int(from_key) handler = self.get_typing_handler() diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 5d79d31579..d9d5655c95 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py
@@ -28,8 +28,15 @@ from typing import ( overload, ) -from pydantic import BaseModel, MissingError, PydanticValueError, ValidationError -from pydantic.error_wrappers import ErrorWrapper +from synapse._pydantic_compat import HAS_PYDANTIC_V2 + +if TYPE_CHECKING or HAS_PYDANTIC_V2: + from pydantic.v1 import BaseModel, MissingError, PydanticValueError, ValidationError + from pydantic.v1.error_wrappers import ErrorWrapper +else: + from pydantic import BaseModel, MissingError, PydanticValueError, ValidationError + from pydantic.error_wrappers import ErrorWrapper + from typing_extensions import Literal from twisted.web.server import Request diff --git a/synapse/media/_base.py b/synapse/media/_base.py
index 20cb8b9010..80c448de2b 100644 --- a/synapse/media/_base.py +++ b/synapse/media/_base.py
@@ -50,6 +50,39 @@ TEXT_CONTENT_TYPES = [ "text/xml", ] +# A list of all content types that are "safe" to be rendered inline in a browser. +INLINE_CONTENT_TYPES = [ + "text/css", + "text/plain", + "text/csv", + "application/json", + "application/ld+json", + # We allow some media files deemed as safe, which comes from the matrix-react-sdk. + # https://github.com/matrix-org/matrix-react-sdk/blob/a70fcfd0bcf7f8c85986da18001ea11597989a7c/src/utils/blobs.ts#L51 + # SVGs are *intentionally* omitted. + "image/jpeg", + "image/gif", + "image/png", + "image/apng", + "image/webp", + "image/avif", + "video/mp4", + "video/webm", + "video/ogg", + "video/quicktime", + "audio/mp4", + "audio/webm", + "audio/aac", + "audio/mpeg", + "audio/ogg", + "audio/wave", + "audio/wav", + "audio/x-wav", + "audio/x-pn-wav", + "audio/flac", + "audio/x-flac", +] + def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]: """Parses the server name, media ID and optional file name from the request URI @@ -153,8 +186,13 @@ def add_file_headers( request.setHeader(b"Content-Type", content_type.encode("UTF-8")) - # Use a Content-Disposition of attachment to force download of media. - disposition = "attachment" + # A strict subset of content types is allowed to be inlined so that they may + # be viewed directly in a browser. Other file types are forced to be downloads. + if media_type.lower() in INLINE_CONTENT_TYPES: + disposition = "inline" + else: + disposition = "attachment" + if upload_name: # RFC6266 section 4.1 [1] defines both `filename` and `filename*`. # diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 554634579e..14784312dc 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -131,7 +131,7 @@ class BulkPushRuleEvaluator: async def _get_rules_for_event( self, event: EventBase, - ) -> Dict[str, FilteredPushRules]: + ) -> Mapping[str, FilteredPushRules]: """Get the push rules for all users who may need to be notified about the event. diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index ca8a76f77c..f4f2b29e96 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py
@@ -205,6 +205,12 @@ class ReplicationDataHandler: self.notifier.notify_user_joined_room( row.data.event_id, row.data.room_id ) + + # If this is a server ACL event, clear the cache in the storage controller. + if row.data.type == EventTypes.ServerACL: + self._state_storage_controller.get_server_acl_for_room.invalidate( + (row.data.room_id,) + ) elif stream_name == UnPartialStatedRoomStream.NAME: for row in rows: assert isinstance(row, UnPartialStatedRoomStreamRow) @@ -333,7 +339,7 @@ class ReplicationDataHandler: try: await make_deferred_yieldable(deferred) except defer.TimeoutError: - logger.error( + logger.warning( "Timed out waiting for repl stream %r to reach %s (%s)" "; currently at: %s", stream_name, diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 7d0b4b55a0..e42dade246 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py
@@ -16,7 +16,6 @@ # limitations under the License. import logging -import platform from http import HTTPStatus from typing import TYPE_CHECKING, Optional, Tuple @@ -107,10 +106,7 @@ class VersionServlet(RestServlet): PATTERNS = admin_patterns("/server_version$") def __init__(self, hs: "HomeServer"): - self.res = { - "server_version": SYNAPSE_VERSION, - "python_version": platform.python_version(), - } + self.res = {"server_version": SYNAPSE_VERSION} def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: return HTTPStatus.OK, self.res diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index 49cd0805fd..e74a87af4d 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py
@@ -18,7 +18,12 @@ import random from typing import TYPE_CHECKING, List, Optional, Tuple from urllib.parse import urlparse -from pydantic import StrictBool, StrictStr, constr +from synapse._pydantic_compat import HAS_PYDANTIC_V2 + +if TYPE_CHECKING or HAS_PYDANTIC_V2: + from pydantic.v1 import StrictBool, StrictStr, constr +else: + from pydantic import StrictBool, StrictStr, constr from typing_extensions import Literal from twisted.web.server import Request diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py
index 925f037743..80ae937921 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py
@@ -17,7 +17,12 @@ import logging from http import HTTPStatus from typing import TYPE_CHECKING, List, Optional, Tuple -from pydantic import Extra, StrictStr +from synapse._pydantic_compat import HAS_PYDANTIC_V2 + +if TYPE_CHECKING or HAS_PYDANTIC_V2: + from pydantic.v1 import Extra, StrictStr +else: + from pydantic import Extra, StrictStr from synapse.api import errors from synapse.api.errors import NotFoundError, SynapseError, UnrecognizedRequestError diff --git a/synapse/rest/client/directory.py b/synapse/rest/client/directory.py
index 570bb52747..82944ca711 100644 --- a/synapse/rest/client/directory.py +++ b/synapse/rest/client/directory.py
@@ -15,7 +15,13 @@ import logging from typing import TYPE_CHECKING, List, Optional, Tuple -from pydantic import StrictStr +from synapse._pydantic_compat import HAS_PYDANTIC_V2 + +if TYPE_CHECKING or HAS_PYDANTIC_V2: + from pydantic.v1 import StrictStr +else: + from pydantic import StrictStr + from typing_extensions import Literal from twisted.web.server import Request diff --git a/synapse/rest/client/filter.py b/synapse/rest/client/filter.py
index 5da1e511a2..b5879496db 100644 --- a/synapse/rest/client/filter.py +++ b/synapse/rest/client/filter.py
@@ -19,7 +19,7 @@ from synapse.api.errors import AuthError, NotFoundError, StoreError, SynapseErro from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.site import SynapseRequest -from synapse.types import JsonDict, UserID +from synapse.types import JsonDict, JsonMapping, UserID from ._base import client_patterns, set_timeline_upper_limit @@ -41,7 +41,7 @@ class GetFilterRestServlet(RestServlet): async def on_GET( self, request: SynapseRequest, user_id: str, filter_id: str - ) -> Tuple[int, JsonDict]: + ) -> Tuple[int, JsonMapping]: target_user = UserID.from_string(user_id) requester = await self.auth.get_user_by_req(request) diff --git a/synapse/rest/client/models.py b/synapse/rest/client/models.py
index 3d7940b0fc..880f79473c 100644 --- a/synapse/rest/client/models.py +++ b/synapse/rest/client/models.py
@@ -13,7 +13,12 @@ # limitations under the License. from typing import TYPE_CHECKING, Dict, Optional -from pydantic import Extra, StrictInt, StrictStr, constr, validator +from synapse._pydantic_compat import HAS_PYDANTIC_V2 + +if TYPE_CHECKING or HAS_PYDANTIC_V2: + from pydantic.v1 import Extra, StrictInt, StrictStr, constr, validator +else: + from pydantic import Extra, StrictInt, StrictStr, constr, validator from synapse.rest.models import RequestBodyModel from synapse.util.threepids import validate_email diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 0aaa838d04..48c47058db 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -16,7 +16,13 @@ import logging import re from typing import TYPE_CHECKING, Dict, Mapping, Optional, Set, Tuple -from pydantic import Extra, StrictInt, StrictStr +from synapse._pydantic_compat import HAS_PYDANTIC_V2 + +if TYPE_CHECKING or HAS_PYDANTIC_V2: + from pydantic.v1 import Extra, StrictInt, StrictStr +else: + from pydantic import StrictInt, StrictStr, Extra + from signedjson.sign import sign_json from twisted.web.server import Request diff --git a/synapse/rest/models.py b/synapse/rest/models.py
index ac39cda8e5..de354a2135 100644 --- a/synapse/rest/models.py +++ b/synapse/rest/models.py
@@ -1,4 +1,24 @@ -from pydantic import BaseModel, Extra +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from synapse._pydantic_compat import HAS_PYDANTIC_V2 + +if TYPE_CHECKING or HAS_PYDANTIC_V2: + from pydantic.v1 import BaseModel, Extra +else: + from pydantic import BaseModel, Extra class RequestBodyModel(BaseModel): diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 1752f95db8..b2e63aed1e 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py
@@ -23,7 +23,6 @@ from typing import ( Generator, Iterable, List, - Mapping, Optional, Sequence, Set, @@ -269,7 +268,7 @@ async def _get_power_level_for_sender( async def _get_auth_chain_difference( room_id: str, - state_sets: Sequence[Mapping[Any, str]], + state_sets: Sequence[StateMap[str]], unpersisted_events: Dict[str, EventBase], state_res_store: StateResolutionStore, ) -> Set[str]: @@ -405,7 +404,7 @@ def _seperate( # mypy doesn't understand that discarding None above means that conflicted # state is StateMap[Set[str]], not StateMap[Set[Optional[Str]]]. - return unconflicted_state, conflicted_state # type: ignore + return unconflicted_state, conflicted_state # type: ignore[return-value] def _is_power_event(event: EventBase) -> bool: diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 99ebd96f84..12829d3d7d 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py
@@ -31,8 +31,8 @@ from typing import ( ) import attr -from pydantic import BaseModel +from synapse._pydantic_compat import HAS_PYDANTIC_V2 from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.engines import PostgresEngine from synapse.storage.types import Connection, Cursor @@ -41,6 +41,11 @@ from synapse.util import Clock, json_encoder from . import engines +if TYPE_CHECKING or HAS_PYDANTIC_V2: + from pydantic.v1 import BaseModel +else: + from pydantic import BaseModel + if TYPE_CHECKING: from synapse.server import HomeServer from synapse.storage.database import DatabasePool, LoggingTransaction diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 278c7832ba..46957723a1 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py
@@ -37,6 +37,7 @@ from synapse.storage.util.partial_state_events_tracker import ( PartialCurrentStateTracker, PartialStateEventsTracker, ) +from synapse.synapse_rust.acl import ServerAclEvaluator from synapse.types import MutableStateMap, StateMap, get_domain_from_id from synapse.types.state import StateFilter from synapse.util.async_helpers import Linearizer @@ -501,6 +502,31 @@ class StateStorageController: return event.content.get("alias") + @cached() + async def get_server_acl_for_room( + self, room_id: str + ) -> Optional[ServerAclEvaluator]: + """Get the server ACL evaluator for room, if any + + This does up-front parsing of the content to ignore bad data and pre-compile + regular expressions. + + Args: + room_id: The room ID + + Returns: + The server ACL evaluator, if any + """ + + acl_event = await self.get_current_state_event( + room_id, EventTypes.ServerACL, "" + ) + + if not acl_event: + return None + + return server_acl_evaluator_from_event(acl_event) + @trace @tag_args async def get_current_state_deltas( @@ -582,7 +608,7 @@ class StateStorageController: @trace @tag_args - async def get_current_hosts_in_room_ordered(self, room_id: str) -> List[str]: + async def get_current_hosts_in_room_ordered(self, room_id: str) -> Tuple[str, ...]: """Get current hosts in room based on current state. Blocks until we have full state for the given room. This only happens for rooms @@ -760,3 +786,36 @@ class StateStorageController: cache.state_group = object() return frozenset(cache.hosts_to_joined_users) + + +def server_acl_evaluator_from_event(acl_event: EventBase) -> "ServerAclEvaluator": + """ + Create a ServerAclEvaluator from a m.room.server_acl event's content. + + This does up-front parsing of the content to ignore bad data. It then creates + the ServerAclEvaluator which will pre-compile regular expressions from the globs. + """ + + # first of all, parse if literal IPs are blocked. + allow_ip_literals = acl_event.content.get("allow_ip_literals", True) + if not isinstance(allow_ip_literals, bool): + logger.warning("Ignoring non-bool allow_ip_literals flag") + allow_ip_literals = True + + # next, parse the deny list by ignoring any non-strings. + deny = acl_event.content.get("deny", []) + if not isinstance(deny, (list, tuple)): + logger.warning("Ignoring non-list deny ACL %s", deny) + deny = [] + else: + deny = [s for s in deny if isinstance(s, str)] + + # then the allow list. + allow = acl_event.content.get("allow", []) + if not isinstance(allow, (list, tuple)): + logger.warning("Ignoring non-list allow ACL %s", allow) + allow = [] + else: + allow = [s for s in allow if isinstance(s, str)] + + return ServerAclEvaluator(allow_ip_literals, allow, deny) diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 697bc5651c..ca894edd5a 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py
@@ -361,19 +361,7 @@ class LoggingTransaction: @property def description( self, - ) -> Optional[ - Sequence[ - Tuple[ - str, - Optional[Any], - Optional[int], - Optional[int], - Optional[int], - Optional[int], - Optional[int], - ] - ] - ]: + ) -> Optional[Sequence[Any]]: return self.txn.description def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None: diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 484db175d0..0553a0621a 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py
@@ -45,7 +45,7 @@ from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.types import Cursor from synapse.storage.util.sequence import build_sequence_generator -from synapse.types import DeviceListUpdates, JsonDict +from synapse.types import DeviceListUpdates, JsonMapping from synapse.util import json_encoder from synapse.util.caches.descriptors import _CacheContext, cached @@ -268,8 +268,8 @@ class ApplicationServiceTransactionWorkerStore( self, service: ApplicationService, events: Sequence[EventBase], - ephemeral: List[JsonDict], - to_device_messages: List[JsonDict], + ephemeral: List[JsonMapping], + to_device_messages: List[JsonMapping], one_time_keys_count: TransactionOneTimeKeysCount, unused_fallback_keys: TransactionUnusedFallbackKeys, device_list_summary: DeviceListUpdates, diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 70faf4b1ec..df596f35f9 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py
@@ -55,7 +55,12 @@ from synapse.storage.util.id_generators import ( AbstractStreamIdGenerator, StreamIdGenerator, ) -from synapse.types import JsonDict, StrCollection, get_verify_key_from_cross_signing_key +from synapse.types import ( + JsonDict, + JsonMapping, + StrCollection, + get_verify_key_from_cross_signing_key, +) from synapse.util import json_decoder, json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.lrucache import LruCache @@ -746,7 +751,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): @cancellable async def get_user_devices_from_cache( self, user_ids: Set[str], user_and_device_ids: List[Tuple[str, str]] - ) -> Tuple[Set[str], Dict[str, Mapping[str, JsonDict]]]: + ) -> Tuple[Set[str], Dict[str, Mapping[str, JsonMapping]]]: """Get the devices (and keys if any) for remote users from the cache. Args: @@ -766,13 +771,13 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): user_ids_not_in_cache = unique_user_ids - user_ids_in_cache # First fetch all the users which all devices are to be returned. - results: Dict[str, Mapping[str, JsonDict]] = {} + results: Dict[str, Mapping[str, JsonMapping]] = {} for user_id in user_ids: if user_id in user_ids_in_cache: results[user_id] = await self.get_cached_devices_for_user(user_id) # Then fetch all device-specific requests, but skip users we've already # fetched all devices for. - device_specific_results: Dict[str, Dict[str, JsonDict]] = {} + device_specific_results: Dict[str, Dict[str, JsonMapping]] = {} for user_id, device_id in user_and_device_ids: if user_id in user_ids_in_cache and user_id not in user_ids: device = await self._get_cached_user_device(user_id, device_id) @@ -801,7 +806,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): return user_ids_in_cache @cached(num_args=2, tree=True) - async def _get_cached_user_device(self, user_id: str, device_id: str) -> JsonDict: + async def _get_cached_user_device( + self, user_id: str, device_id: str + ) -> JsonMapping: content = await self.db_pool.simple_select_one_onecol( table="device_lists_remote_cache", keyvalues={"user_id": user_id, "device_id": device_id}, @@ -811,7 +818,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): return db_to_json(content) @cached() - async def get_cached_devices_for_user(self, user_id: str) -> Mapping[str, JsonDict]: + async def get_cached_devices_for_user( + self, user_id: str + ) -> Mapping[str, JsonMapping]: devices = await self.db_pool.simple_select_list( table="device_lists_remote_cache", keyvalues={"user_id": user_id}, @@ -1042,7 +1051,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ) async def get_device_list_last_stream_id_for_remotes( self, user_ids: Iterable[str] - ) -> Dict[str, Optional[str]]: + ) -> Mapping[str, Optional[str]]: rows = await self.db_pool.simple_select_many_batch( table="device_lists_remote_extremeties", column="user_id", diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index b49dea577c..89fac23f93 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -52,7 +52,7 @@ from synapse.storage.database import ( from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import StreamIdGenerator -from synapse.types import JsonDict +from synapse.types import JsonDict, JsonMapping from synapse.util import json_decoder, json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.cancellation import cancellable @@ -125,7 +125,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker async def get_e2e_device_keys_for_federation_query( self, user_id: str - ) -> Tuple[int, List[JsonDict]]: + ) -> Tuple[int, Sequence[JsonMapping]]: """Get all devices (with any device keys) for a user Returns: @@ -174,7 +174,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker @cached(iterable=True) async def _get_e2e_device_keys_for_federation_query_inner( self, user_id: str - ) -> List[JsonDict]: + ) -> Sequence[JsonMapping]: """Get all devices (with any device keys) for a user""" devices = await self.get_e2e_device_keys_and_signatures([(user_id, None)]) @@ -578,7 +578,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker @cached(max_entries=10000) async def count_e2e_one_time_keys( self, user_id: str, device_id: str - ) -> Dict[str, int]: + ) -> Mapping[str, int]: """Count the number of one time keys the server has for a device Returns: A mapping from algorithm to number of keys for that algorithm. @@ -812,7 +812,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker async def get_e2e_cross_signing_key( self, user_id: str, key_type: str, from_user_id: Optional[str] = None - ) -> Optional[JsonDict]: + ) -> Optional[JsonMapping]: """Returns a user's cross-signing key. Args: @@ -833,7 +833,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker return user_keys.get(key_type) @cached(num_args=1) - def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Mapping[str, JsonDict]: + def _get_bare_e2e_cross_signing_keys( + self, user_id: str + ) -> Mapping[str, JsonMapping]: """Dummy function. Only used to make a cache for _get_bare_e2e_cross_signing_keys_bulk. """ @@ -846,7 +848,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker ) async def _get_bare_e2e_cross_signing_keys_bulk( self, user_ids: Iterable[str] - ) -> Dict[str, Optional[Mapping[str, JsonDict]]]: + ) -> Mapping[str, Optional[Mapping[str, JsonMapping]]]: """Returns the cross-signing keys for a set of users. The output of this function should be passed to _get_e2e_cross_signing_signatures_txn if the signatures for the calling user need to be fetched. @@ -860,15 +862,12 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker their user ID will map to None. """ - result = await self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_bare_e2e_cross_signing_keys_bulk", self._get_bare_e2e_cross_signing_keys_bulk_txn, user_ids, ) - # The `Optional` comes from the `@cachedList` decorator. - return cast(Dict[str, Optional[Mapping[str, JsonDict]]], result) - def _get_bare_e2e_cross_signing_keys_bulk_txn( self, txn: LoggingTransaction, @@ -1026,7 +1025,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker @cancellable async def get_e2e_cross_signing_keys_bulk( self, user_ids: List[str], from_user_id: Optional[str] = None - ) -> Dict[str, Optional[Mapping[str, JsonDict]]]: + ) -> Mapping[str, Optional[Mapping[str, JsonMapping]]]: """Returns the cross-signing keys for a set of users. Args: @@ -1043,7 +1042,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker if from_user_id: result = cast( - Dict[str, Optional[Mapping[str, JsonDict]]], + Dict[str, Optional[Mapping[str, JsonMapping]]], await self.db_pool.runInteraction( "get_e2e_cross_signing_signatures", self._get_e2e_cross_signing_signatures_txn, diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 943666ed4f..8737a1370e 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py
@@ -24,6 +24,7 @@ from typing import ( Dict, Iterable, List, + Mapping, MutableMapping, Optional, Set, @@ -1633,7 +1634,7 @@ class EventsWorkerStore(SQLBaseStore): self, room_id: str, event_ids: Collection[str], - ) -> Dict[str, bool]: + ) -> Mapping[str, bool]: """Helper for have_seen_events Returns: @@ -2329,7 +2330,7 @@ class EventsWorkerStore(SQLBaseStore): @cachedList(cached_method_name="is_partial_state_event", list_name="event_ids") async def get_partial_state_events( self, event_ids: Collection[str] - ) -> Dict[str, bool]: + ) -> Mapping[str, bool]: """Checks which of the given events have partial state Args: diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index 047de6283a..7d94685caf 100644 --- a/synapse/storage/databases/main/filtering.py +++ b/synapse/storage/databases/main/filtering.py
@@ -25,7 +25,7 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.engines import PostgresEngine -from synapse.types import JsonDict, UserID +from synapse.types import JsonDict, JsonMapping, UserID from synapse.util.caches.descriptors import cached if TYPE_CHECKING: @@ -145,7 +145,7 @@ class FilteringWorkerStore(SQLBaseStore): @cached(num_args=2) async def get_user_filter( self, user_id: UserID, filter_id: Union[int, str] - ) -> JsonDict: + ) -> JsonMapping: # filter_id is BIGINT UNSIGNED, so if it isn't a number, fail # with a coherent error message rather than 500 M_UNKNOWN. try: diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index 41563371dc..889c578b9c 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py
@@ -16,7 +16,7 @@ import itertools import json import logging -from typing import Dict, Iterable, Optional, Tuple +from typing import Dict, Iterable, Mapping, Optional, Tuple from canonicaljson import encode_canonical_json from signedjson.key import decode_verify_key_bytes @@ -130,7 +130,7 @@ class KeyStore(CacheInvalidationWorkerStore): ) async def get_server_keys_json( self, server_name_and_key_ids: Iterable[Tuple[str, str]] - ) -> Dict[Tuple[str, str], FetchKeyResult]: + ) -> Mapping[Tuple[str, str], FetchKeyResult]: """ Args: server_name_and_key_ids: @@ -200,7 +200,7 @@ class KeyStore(CacheInvalidationWorkerStore): ) async def get_server_keys_json_for_remote( self, server_name: str, key_ids: Iterable[str] - ) -> Dict[str, Optional[FetchKeyResultForRemote]]: + ) -> Mapping[str, Optional[FetchKeyResultForRemote]]: """Fetch the cached keys for the given server/key IDs. If we have multiple entries for a given key ID, returns the most recent. diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index b51d20ac26..194b4e031f 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py
@@ -11,7 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, cast +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + List, + Mapping, + Optional, + Tuple, + cast, +) from synapse.api.presence import PresenceState, UserPresenceState from synapse.replication.tcp.streams import PresenceStream @@ -249,7 +259,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore) ) async def get_presence_for_users( self, user_ids: Iterable[str] - ) -> Dict[str, UserPresenceState]: + ) -> Mapping[str, UserPresenceState]: rows = await self.db_pool.simple_select_many_batch( table="presence_stream", column="user_id", diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index bec0dc2afe..923166974c 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py
@@ -88,6 +88,7 @@ def _load_rules( msc1767_enabled=experimental_config.msc1767_enabled, msc3664_enabled=experimental_config.msc3664_enabled, msc3381_polls_enabled=experimental_config.msc3381_polls_enabled, + msc4028_push_encrypted_events=experimental_config.msc4028_push_encrypted_events, ) return filtered_rules @@ -216,7 +217,7 @@ class PushRulesWorkerStore( @cachedList(cached_method_name="get_push_rules_for_user", list_name="user_ids") async def bulk_get_push_rules( self, user_ids: Collection[str] - ) -> Dict[str, FilteredPushRules]: + ) -> Mapping[str, FilteredPushRules]: if not user_ids: return {} diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index a074c43989..0231f9407b 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py
@@ -43,7 +43,7 @@ from synapse.storage.util.id_generators import ( MultiWriterIdGenerator, StreamIdGenerator, ) -from synapse.types import JsonDict +from synapse.types import JsonDict, JsonMapping from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -218,7 +218,7 @@ class ReceiptsWorkerStore(SQLBaseStore): @cached() async def _get_receipts_for_user_with_orderings( self, user_id: str, receipt_type: str - ) -> JsonDict: + ) -> JsonMapping: """ Fetch receipts for all rooms that the given user is joined to. @@ -258,7 +258,7 @@ class ReceiptsWorkerStore(SQLBaseStore): async def get_linearized_receipts_for_rooms( self, room_ids: Iterable[str], to_key: int, from_key: Optional[int] = None - ) -> List[dict]: + ) -> List[JsonMapping]: """Get receipts for multiple rooms for sending to clients. Args: @@ -287,7 +287,7 @@ class ReceiptsWorkerStore(SQLBaseStore): async def get_linearized_receipts_for_room( self, room_id: str, to_key: int, from_key: Optional[int] = None - ) -> Sequence[JsonDict]: + ) -> Sequence[JsonMapping]: """Get receipts for a single room for sending to clients. Args: @@ -310,7 +310,7 @@ class ReceiptsWorkerStore(SQLBaseStore): @cached(tree=True) async def _get_linearized_receipts_for_room( self, room_id: str, to_key: int, from_key: Optional[int] = None - ) -> Sequence[JsonDict]: + ) -> Sequence[JsonMapping]: """See get_linearized_receipts_for_room""" def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: @@ -353,7 +353,7 @@ class ReceiptsWorkerStore(SQLBaseStore): ) async def _get_linearized_receipts_for_rooms( self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None - ) -> Dict[str, Sequence[JsonDict]]: + ) -> Mapping[str, Sequence[JsonMapping]]: if not room_ids: return {} @@ -415,7 +415,7 @@ class ReceiptsWorkerStore(SQLBaseStore): ) async def get_linearized_receipts_for_all_rooms( self, to_key: int, from_key: Optional[int] = None - ) -> Mapping[str, JsonDict]: + ) -> Mapping[str, JsonMapping]: """Get receipts for all rooms between two stream_ids, up to a limit of the latest 100 read receipts. diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 96908f14ba..b67f780c10 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py
@@ -465,7 +465,7 @@ class RelationsWorkerStore(SQLBaseStore): @cachedList(cached_method_name="get_references_for_event", list_name="event_ids") async def get_references_for_events( self, event_ids: Collection[str] - ) -> Mapping[str, Optional[List[_RelatedEvent]]]: + ) -> Mapping[str, Optional[Sequence[_RelatedEvent]]]: """Get a list of references to the given events. Args: @@ -519,7 +519,7 @@ class RelationsWorkerStore(SQLBaseStore): @cachedList(cached_method_name="get_applicable_edit", list_name="event_ids") async def get_applicable_edits( self, event_ids: Collection[str] - ) -> Dict[str, Optional[EventBase]]: + ) -> Mapping[str, Optional[EventBase]]: """Get the most recent edit (if any) that has happened for the given events. @@ -605,7 +605,7 @@ class RelationsWorkerStore(SQLBaseStore): @cachedList(cached_method_name="get_thread_summary", list_name="event_ids") async def get_thread_summaries( self, event_ids: Collection[str] - ) -> Dict[str, Optional[Tuple[int, EventBase]]]: + ) -> Mapping[str, Optional[Tuple[int, EventBase]]]: """Get the number of threaded replies and the latest reply (if any) for the given events. Args: @@ -779,7 +779,7 @@ class RelationsWorkerStore(SQLBaseStore): @cachedList(cached_method_name="get_thread_participated", list_name="event_ids") async def get_threads_participated( self, event_ids: Collection[str], user_id: str - ) -> Dict[str, bool]: + ) -> Mapping[str, bool]: """Get whether the requesting user participated in the given threads. This is separate from get_thread_summaries since that can be cached across @@ -931,7 +931,7 @@ class RelationsWorkerStore(SQLBaseStore): room_id: str, limit: int = 5, from_token: Optional[ThreadsNextBatch] = None, - ) -> Tuple[List[str], Optional[ThreadsNextBatch]]: + ) -> Tuple[Sequence[str], Optional[ThreadsNextBatch]]: """Get a list of thread IDs, ordered by topological ordering of their latest reply. diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index fff259f74c..3755773faa 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py
@@ -191,7 +191,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): ) async def get_subset_users_in_room_with_profiles( self, room_id: str, user_ids: Collection[str] - ) -> Dict[str, ProfileInfo]: + ) -> Mapping[str, ProfileInfo]: """Get a mapping from user ID to profile information for a list of users in a given room. @@ -676,7 +676,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): ) async def _get_rooms_for_users( self, user_ids: Collection[str] - ) -> Dict[str, FrozenSet[str]]: + ) -> Mapping[str, FrozenSet[str]]: """A batched version of `get_rooms_for_user`. Returns: @@ -881,7 +881,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): ) async def _get_user_ids_from_membership_event_ids( self, event_ids: Iterable[str] - ) -> Dict[str, Optional[str]]: + ) -> Mapping[str, Optional[str]]: """For given set of member event_ids check if they point to a join event. @@ -984,7 +984,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): ) @cached(iterable=True, max_entries=10000) - async def get_current_hosts_in_room_ordered(self, room_id: str) -> List[str]: + async def get_current_hosts_in_room_ordered(self, room_id: str) -> Tuple[str, ...]: """ Get current hosts in room based on current state. @@ -1013,12 +1013,14 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): # `get_users_in_room` rather than funky SQL. domains = await self.get_current_hosts_in_room(room_id) - return list(domains) + return tuple(domains) # For PostgreSQL we can use a regex to pull out the domains from the # joined users in `current_state_events` via regex. - def get_current_hosts_in_room_ordered_txn(txn: LoggingTransaction) -> List[str]: + def get_current_hosts_in_room_ordered_txn( + txn: LoggingTransaction, + ) -> Tuple[str, ...]: # Returns a list of servers currently joined in the room sorted by # longest in the room first (aka. with the lowest depth). The # heuristic of sorting by servers who have been in the room the @@ -1043,7 +1045,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): """ txn.execute(sql, (room_id,)) # `server_domain` will be `NULL` for malformed MXIDs with no colons. - return [d for d, in txn if d is not None] + return tuple(d for d, in txn if d is not None) return await self.db_pool.runInteraction( "get_current_hosts_in_room_ordered", get_current_hosts_in_room_ordered_txn @@ -1191,7 +1193,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): ) async def get_membership_from_event_ids( self, member_event_ids: Iterable[str] - ) -> Dict[str, Optional[EventIdMembership]]: + ) -> Mapping[str, Optional[EventIdMembership]]: """Get user_id and membership of a set of event IDs. Returns: diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index ebb2ae964f..5eaaff5b68 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py
@@ -14,7 +14,17 @@ # limitations under the License. import collections.abc import logging -from typing import TYPE_CHECKING, Any, Collection, Dict, Iterable, Optional, Set, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Collection, + Dict, + Iterable, + Mapping, + Optional, + Set, + Tuple, +) import attr @@ -372,7 +382,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ) async def _get_state_group_for_events( self, event_ids: Collection[str] - ) -> Dict[str, int]: + ) -> Mapping[str, int]: """Returns mapping event_id -> state_group. Raises: diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index efd21b5bfc..8f70eff809 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py
@@ -14,7 +14,7 @@ import logging from enum import Enum -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, cast +from typing import TYPE_CHECKING, Iterable, List, Mapping, Optional, Tuple, cast import attr from canonicaljson import encode_canonical_json @@ -210,7 +210,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore): ) async def get_destination_retry_timings_batch( self, destinations: StrCollection - ) -> Dict[str, Optional[DestinationRetryTimings]]: + ) -> Mapping[str, Optional[DestinationRetryTimings]]: rows = await self.db_pool.simple_select_many_batch( table="destinations", iterable=destinations, diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index f79006533f..06fcbe5e54 100644 --- a/synapse/storage/databases/main/user_erasure_store.py +++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Iterable +from typing import Iterable, Mapping from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main import CacheInvalidationWorkerStore @@ -40,7 +40,7 @@ class UserErasureWorkerStore(CacheInvalidationWorkerStore): return bool(result) @cachedList(cached_method_name="is_user_erased", list_name="user_ids") - async def are_users_erased(self, user_ids: Iterable[str]) -> Dict[str, bool]: + async def are_users_erased(self, user_ids: Iterable[str]) -> Mapping[str, bool]: """ Checks which users in a list have requested erasure diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index 5b8ba436d4..6ff533a129 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py
@@ -94,6 +94,18 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): groups: List[int], state_filter: Optional[StateFilter] = None, ) -> Mapping[int, StateMap[str]]: + """ + Given a number of state groups, fetch the latest state for each group. + + Args: + txn: The transaction object. + groups: The given state groups that you want to fetch the latest state for. + state_filter: The state filter to apply the state we fetch state from the database. + + Returns: + Map from state_group to a StateMap at that point. + """ + state_filter = state_filter or StateFilter.all() results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups} @@ -206,8 +218,10 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): if where_clause: where_clause = " AND (%s)" % (where_clause,) - # We don't use WITH RECURSIVE on sqlite3 as there are distributions - # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) + # XXX: We could `WITH RECURSIVE` here since it's supported on SQLite 3.8.3 + # or higher and our minimum supported version is greater than that. + # + # We just haven't put in the time to refactor this. for group in groups: next_group: Optional[int] = group diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index 34ac807530..afaeef9a5a 100644 --- a/synapse/storage/types.py +++ b/synapse/storage/types.py
@@ -53,22 +53,10 @@ class Cursor(Protocol): @property def description( self, - ) -> Optional[ - Sequence[ - # Note that this is an approximate typing based on sqlite3 and other - # drivers, and may not be entirely accurate. - # FWIW, the DBAPI 2 spec is: https://peps.python.org/pep-0249/#description - Tuple[ - str, - Optional[Any], - Optional[int], - Optional[int], - Optional[int], - Optional[int], - Optional[int], - ] - ] - ]: + ) -> Optional[Sequence[Any]]: + # At the time of writing, Synapse only assumes that `column[0]: str` for each + # `column in description`. Since this is hard to express in the type system, and + # as this is rarely used in Synapse, we deem `column: Any` good enough. ... @property