diff --git a/synapse/_pydantic_compat.py b/synapse/_pydantic_compat.py
new file mode 100644
index 0000000000..ddff72afa1
--- /dev/null
+++ b/synapse/_pydantic_compat.py
@@ -0,0 +1,26 @@
+# Copyright 2023 Maxwell G <maxwell@gtmx.me>
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from packaging.version import Version
+
+try:
+ from pydantic import __version__ as pydantic_version
+except ImportError:
+ import importlib.metadata
+
+ pydantic_version = importlib.metadata.version("pydantic")
+
+HAS_PYDANTIC_V2: bool = Version(pydantic_version).major == 2
+
+__all__ = ("HAS_PYDANTIC_V2",)
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 0995ecbe83..74ee8e9f3f 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -37,7 +37,7 @@ from synapse.api.constants import EduTypes, EventContentFields
from synapse.api.errors import SynapseError
from synapse.api.presence import UserPresenceState
from synapse.events import EventBase, relation_from_event
-from synapse.types import JsonDict, RoomID, UserID
+from synapse.types import JsonDict, JsonMapping, RoomID, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -191,7 +191,7 @@ FilterEvent = TypeVar("FilterEvent", EventBase, UserPresenceState, JsonDict)
class FilterCollection:
- def __init__(self, hs: "HomeServer", filter_json: JsonDict):
+ def __init__(self, hs: "HomeServer", filter_json: JsonMapping):
self._filter_json = filter_json
room_filter_json = self._filter_json.get("room", {})
@@ -219,7 +219,7 @@ class FilterCollection:
def __repr__(self) -> str:
return "<FilterCollection %s>" % (json.dumps(self._filter_json),)
- def get_filter_json(self) -> JsonDict:
+ def get_filter_json(self) -> JsonMapping:
return self._filter_json
def timeline_limit(self) -> int:
@@ -313,7 +313,7 @@ class FilterCollection:
class Filter:
- def __init__(self, hs: "HomeServer", filter_json: JsonDict):
+ def __init__(self, hs: "HomeServer", filter_json: JsonMapping):
self._hs = hs
self._store = hs.get_datastores().main
self.filter_json = filter_json
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index 2260a8f589..6f4aa53c93 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -23,7 +23,7 @@ from netaddr import IPSet
from synapse.api.constants import EventTypes
from synapse.events import EventBase
-from synapse.types import DeviceListUpdates, JsonDict, UserID
+from synapse.types import DeviceListUpdates, JsonDict, JsonMapping, UserID
from synapse.util.caches.descriptors import _CacheContext, cached
if TYPE_CHECKING:
@@ -379,8 +379,8 @@ class AppServiceTransaction:
service: ApplicationService,
id: int,
events: Sequence[EventBase],
- ephemeral: List[JsonDict],
- to_device_messages: List[JsonDict],
+ ephemeral: List[JsonMapping],
+ to_device_messages: List[JsonMapping],
one_time_keys_count: TransactionOneTimeKeysCount,
unused_fallback_keys: TransactionUnusedFallbackKeys,
device_list_summary: DeviceListUpdates,
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index b1523be208..c42e1f11aa 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -41,7 +41,7 @@ from synapse.events import EventBase
from synapse.events.utils import SerializeEventConfig, serialize_event
from synapse.http.client import SimpleHttpClient, is_unknown_endpoint
from synapse.logging import opentracing
-from synapse.types import DeviceListUpdates, JsonDict, ThirdPartyInstanceID
+from synapse.types import DeviceListUpdates, JsonDict, JsonMapping, ThirdPartyInstanceID
from synapse.util.caches.response_cache import ResponseCache
if TYPE_CHECKING:
@@ -306,8 +306,8 @@ class ApplicationServiceApi(SimpleHttpClient):
self,
service: "ApplicationService",
events: Sequence[EventBase],
- ephemeral: List[JsonDict],
- to_device_messages: List[JsonDict],
+ ephemeral: List[JsonMapping],
+ to_device_messages: List[JsonMapping],
one_time_keys_count: TransactionOneTimeKeysCount,
unused_fallback_keys: TransactionUnusedFallbackKeys,
device_list_summary: DeviceListUpdates,
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index 79f95f7653..18a30bc376 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -73,7 +73,7 @@ from synapse.events import EventBase
from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases.main import DataStore
-from synapse.types import DeviceListUpdates, JsonDict
+from synapse.types import DeviceListUpdates, JsonMapping
from synapse.util import Clock
if TYPE_CHECKING:
@@ -121,8 +121,8 @@ class ApplicationServiceScheduler:
self,
appservice: ApplicationService,
events: Optional[Collection[EventBase]] = None,
- ephemeral: Optional[Collection[JsonDict]] = None,
- to_device_messages: Optional[Collection[JsonDict]] = None,
+ ephemeral: Optional[Collection[JsonMapping]] = None,
+ to_device_messages: Optional[Collection[JsonMapping]] = None,
device_list_summary: Optional[DeviceListUpdates] = None,
) -> None:
"""
@@ -180,9 +180,9 @@ class _ServiceQueuer:
# dict of {service_id: [events]}
self.queued_events: Dict[str, List[EventBase]] = {}
# dict of {service_id: [events]}
- self.queued_ephemeral: Dict[str, List[JsonDict]] = {}
+ self.queued_ephemeral: Dict[str, List[JsonMapping]] = {}
# dict of {service_id: [to_device_message_json]}
- self.queued_to_device_messages: Dict[str, List[JsonDict]] = {}
+ self.queued_to_device_messages: Dict[str, List[JsonMapping]] = {}
# dict of {service_id: [device_list_summary]}
self.queued_device_list_summaries: Dict[str, List[DeviceListUpdates]] = {}
@@ -293,8 +293,8 @@ class _ServiceQueuer:
self,
service: ApplicationService,
events: Iterable[EventBase],
- ephemerals: Iterable[JsonDict],
- to_device_messages: Iterable[JsonDict],
+ ephemerals: Iterable[JsonMapping],
+ to_device_messages: Iterable[JsonMapping],
) -> Tuple[TransactionOneTimeKeysCount, TransactionUnusedFallbackKeys]:
"""
Given a list of the events, ephemeral messages and to-device messages,
@@ -364,8 +364,8 @@ class _TransactionController:
self,
service: ApplicationService,
events: Sequence[EventBase],
- ephemeral: Optional[List[JsonDict]] = None,
- to_device_messages: Optional[List[JsonDict]] = None,
+ ephemeral: Optional[List[JsonMapping]] = None,
+ to_device_messages: Optional[List[JsonMapping]] = None,
one_time_keys_count: Optional[TransactionOneTimeKeysCount] = None,
unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None,
device_list_summary: Optional[DeviceListUpdates] = None,
diff --git a/synapse/config/_util.py b/synapse/config/_util.py
index acccca413b..746838eee3 100644
--- a/synapse/config/_util.py
+++ b/synapse/config/_util.py
@@ -11,10 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, Type, TypeVar
+from typing import TYPE_CHECKING, Any, Dict, Type, TypeVar
import jsonschema
-from pydantic import BaseModel, ValidationError, parse_obj_as
+
+from synapse._pydantic_compat import HAS_PYDANTIC_V2
+
+if TYPE_CHECKING or HAS_PYDANTIC_V2:
+ from pydantic.v1 import BaseModel, ValidationError, parse_obj_as
+else:
+ from pydantic import BaseModel, ValidationError, parse_obj_as
from synapse.config._base import ConfigError
from synapse.types import JsonDict, StrSequence
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index cabe0d4397..9f830e7094 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -415,3 +415,7 @@ class ExperimentalConfig(Config):
LimitExceededError.include_retry_after_header = experimental.get(
"msc4041_enabled", False
)
+
+ self.msc4028_push_encrypted_events = experimental.get(
+ "msc4028_push_encrypted_events", False
+ )
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index 6567fb6bb0..f1766088fc 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -15,10 +15,16 @@
import argparse
import logging
-from typing import Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import attr
-from pydantic import BaseModel, Extra, StrictBool, StrictInt, StrictStr
+
+from synapse._pydantic_compat import HAS_PYDANTIC_V2
+
+if TYPE_CHECKING or HAS_PYDANTIC_V2:
+ from pydantic.v1 import BaseModel, Extra, StrictBool, StrictInt, StrictStr
+else:
+ from pydantic import BaseModel, Extra, StrictBool, StrictInt, StrictStr
from synapse.config._base import (
Config,
diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index 5da50cb0d2..83d9fb5813 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -12,10 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections.abc
-from typing import List, Type, Union, cast
+from typing import TYPE_CHECKING, List, Type, Union, cast
import jsonschema
-from pydantic import Field, StrictBool, StrictStr
+
+from synapse._pydantic_compat import HAS_PYDANTIC_V2
+
+if TYPE_CHECKING or HAS_PYDANTIC_V2:
+ from pydantic.v1 import Field, StrictBool, StrictStr
+else:
+ from pydantic import Field, StrictBool, StrictStr
from synapse.api.constants import (
MAX_ALIAS_LENGTH,
@@ -33,9 +39,9 @@ from synapse.events.utils import (
CANONICALJSON_MIN_INT,
validate_canonicaljson,
)
-from synapse.federation.federation_server import server_matches_acl_event
from synapse.http.servlet import validate_json_object
from synapse.rest.models import RequestBodyModel
+from synapse.storage.controllers.state import server_acl_evaluator_from_event
from synapse.types import EventID, JsonDict, RoomID, StrCollection, UserID
@@ -100,7 +106,10 @@ class EventValidator:
self._validate_retention(event)
elif event.type == EventTypes.ServerACL:
- if not server_matches_acl_event(config.server.server_name, event):
+ server_acl_evaluator = server_acl_evaluator_from_event(event)
+ if not server_acl_evaluator.server_matches_acl_event(
+ config.server.server_name
+ ):
raise SynapseError(
400, "Can't create an ACL event that denies the local server"
)
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 607013f121..c8bc46415d 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -64,7 +64,7 @@ from synapse.federation.transport.client import SendJoinResponse
from synapse.http.client import is_unknown_endpoint
from synapse.http.types import QueryParams
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, tag_args, trace
-from synapse.types import JsonDict, UserID, get_domain_from_id
+from synapse.types import JsonDict, StrCollection, UserID, get_domain_from_id
from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination
@@ -1704,7 +1704,7 @@ class FederationClient(FederationBase):
async def timestamp_to_event(
self,
*,
- destinations: List[str],
+ destinations: StrCollection,
room_id: str,
timestamp: int,
direction: Direction,
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index f9915e5a3f..ec8e770430 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -29,10 +29,8 @@ from typing import (
Union,
)
-from matrix_common.regex import glob_to_regex
from prometheus_client import Counter, Gauge, Histogram
-from twisted.internet.abstract import isIPAddress
from twisted.python import failure
from synapse.api.constants import (
@@ -1324,75 +1322,13 @@ class FederationServer(FederationBase):
Raises:
AuthError if the server does not match the ACL
"""
- acl_event = await self._storage_controllers.state.get_current_state_event(
- room_id, EventTypes.ServerACL, ""
+ server_acl_evaluator = (
+ await self._storage_controllers.state.get_server_acl_for_room(room_id)
)
- if not acl_event or server_matches_acl_event(server_name, acl_event):
- return
-
- raise AuthError(code=403, msg="Server is banned from room")
-
-
-def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool:
- """Check if the given server is allowed by the ACL event
-
- Args:
- server_name: name of server, without any port part
- acl_event: m.room.server_acl event
-
- Returns:
- True if this server is allowed by the ACLs
- """
- logger.debug("Checking %s against acl %s", server_name, acl_event.content)
-
- # first of all, check if literal IPs are blocked, and if so, whether the
- # server name is a literal IP
- allow_ip_literals = acl_event.content.get("allow_ip_literals", True)
- if not isinstance(allow_ip_literals, bool):
- logger.warning("Ignoring non-bool allow_ip_literals flag")
- allow_ip_literals = True
- if not allow_ip_literals:
- # check for ipv6 literals. These start with '['.
- if server_name[0] == "[":
- return False
-
- # check for ipv4 literals. We can just lift the routine from twisted.
- if isIPAddress(server_name):
- return False
-
- # next, check the deny list
- deny = acl_event.content.get("deny", [])
- if not isinstance(deny, (list, tuple)):
- logger.warning("Ignoring non-list deny ACL %s", deny)
- deny = []
- for e in deny:
- if _acl_entry_matches(server_name, e):
- # logger.info("%s matched deny rule %s", server_name, e)
- return False
-
- # then the allow list.
- allow = acl_event.content.get("allow", [])
- if not isinstance(allow, (list, tuple)):
- logger.warning("Ignoring non-list allow ACL %s", allow)
- allow = []
- for e in allow:
- if _acl_entry_matches(server_name, e):
- # logger.info("%s matched allow rule %s", server_name, e)
- return True
-
- # everything else should be rejected.
- # logger.info("%s fell through", server_name)
- return False
-
-
-def _acl_entry_matches(server_name: str, acl_entry: Any) -> bool:
- if not isinstance(acl_entry, str):
- logger.warning(
- "Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)
- )
- return False
- regex = glob_to_regex(acl_entry)
- return bool(regex.match(server_name))
+ if server_acl_evaluator and not server_acl_evaluator.server_matches_acl_event(
+ server_name
+ ):
+ raise AuthError(code=403, msg="Server is banned from room")
class FederationHandlerRegistry:
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 6429545c98..7de7bd3289 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -46,6 +46,7 @@ from synapse.storage.databases.main.directory import RoomAliasMapping
from synapse.types import (
DeviceListUpdates,
JsonDict,
+ JsonMapping,
RoomAlias,
RoomStreamToken,
StreamKeyType,
@@ -397,7 +398,7 @@ class ApplicationServicesHandler:
async def _handle_typing(
self, service: ApplicationService, new_token: int
- ) -> List[JsonDict]:
+ ) -> List[JsonMapping]:
"""
Return the typing events since the given stream token that the given application
service should receive.
@@ -432,7 +433,7 @@ class ApplicationServicesHandler:
async def _handle_receipts(
self, service: ApplicationService, new_token: int
- ) -> List[JsonDict]:
+ ) -> List[JsonMapping]:
"""
Return the latest read receipts that the given application service should receive.
@@ -471,7 +472,7 @@ class ApplicationServicesHandler:
service: ApplicationService,
users: Collection[Union[str, UserID]],
new_token: Optional[int],
- ) -> List[JsonDict]:
+ ) -> List[JsonMapping]:
"""
Return the latest presence updates that the given application service should receive.
@@ -491,7 +492,7 @@ class ApplicationServicesHandler:
A list of json dictionaries containing data derived from the presence events
that should be sent to the given application service.
"""
- events: List[JsonDict] = []
+ events: List[JsonMapping] = []
presence_source = self.event_sources.sources.presence
from_key = await self.store.get_type_stream_id_for_appservice(
service, "presence"
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index ad075497c8..8c6432035d 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Tuple
import attr
from canonicaljson import encode_canonical_json
@@ -31,6 +31,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
from synapse.types import (
JsonDict,
+ JsonMapping,
UserID,
get_domain_from_id,
get_verify_key_from_cross_signing_key,
@@ -272,11 +273,7 @@ class E2eKeysHandler:
delay_cancellation=True,
)
- ret = {"device_keys": results, "failures": failures}
-
- ret.update(cross_signing_keys)
-
- return ret
+ return {"device_keys": results, "failures": failures, **cross_signing_keys}
@trace
async def _query_devices_for_destination(
@@ -408,7 +405,7 @@ class E2eKeysHandler:
@cancellable
async def get_cross_signing_keys_from_cache(
self, query: Iterable[str], from_user_id: Optional[str]
- ) -> Dict[str, Dict[str, dict]]:
+ ) -> Dict[str, Dict[str, JsonMapping]]:
"""Get cross-signing keys for users from the database
Args:
@@ -551,16 +548,13 @@ class E2eKeysHandler:
self.config.federation.allow_device_name_lookup_over_federation
),
)
- ret = {"device_keys": res}
# add in the cross-signing keys
cross_signing_keys = await self.get_cross_signing_keys_from_cache(
device_keys_query, None
)
- ret.update(cross_signing_keys)
-
- return ret
+ return {"device_keys": res, **cross_signing_keys}
async def claim_local_one_time_keys(
self,
@@ -1127,7 +1121,7 @@ class E2eKeysHandler:
user_id: str,
master_key_id: str,
signed_master_key: JsonDict,
- stored_master_key: JsonDict,
+ stored_master_key: JsonMapping,
devices: Dict[str, Dict[str, JsonDict]],
) -> List["SignatureListItem"]:
"""Check signatures of a user's master key made by their devices.
@@ -1278,7 +1272,7 @@ class E2eKeysHandler:
async def _get_e2e_cross_signing_verify_key(
self, user_id: str, key_type: str, from_user_id: Optional[str] = None
- ) -> Tuple[JsonDict, str, VerifyKey]:
+ ) -> Tuple[JsonMapping, str, VerifyKey]:
"""Fetch locally or remotely query for a cross-signing public key.
First, attempt to fetch the cross-signing public key from storage.
@@ -1333,7 +1327,7 @@ class E2eKeysHandler:
self,
user: UserID,
desired_key_type: str,
- ) -> Optional[Tuple[Dict[str, Any], str, VerifyKey]]:
+ ) -> Optional[Tuple[JsonMapping, str, VerifyKey]]:
"""Queries cross-signing keys for a remote user and saves them to the database
Only the key specified by `key_type` will be returned, while all retrieved keys
@@ -1474,7 +1468,7 @@ def _check_device_signature(
user_id: str,
verify_key: VerifyKey,
signed_device: JsonDict,
- stored_device: JsonDict,
+ stored_device: JsonMapping,
) -> None:
"""Check that a signature on a device or cross-signing key is correct and
matches the copy of the device/key that we have stored. Throws an
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index eedde97ab0..0cc8e990d9 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -1538,7 +1538,7 @@ class FederationEventHandler:
logger.exception("Failed to resync device for %s", sender)
async def backfill_event_id(
- self, destinations: List[str], room_id: str, event_id: str
+ self, destinations: StrCollection, room_id: str, event_id: str
) -> PulledPduInfo:
"""Backfill a single event and persist it as a non-outlier which means
we also pull in all of the state and auth events necessary for it.
@@ -2342,6 +2342,12 @@ class FederationEventHandler:
# TODO retrieve the previous state, and exclude join -> join transitions
self._notifier.notify_user_joined_room(event.event_id, event.room_id)
+ # If this is a server ACL event, clear the cache in the storage controller.
+ if event.type == EventTypes.ServerACL:
+ self._state_storage_controller.get_server_acl_for_room.invalidate(
+ (event.room_id,)
+ )
+
def _sanity_check_event(self, ev: EventBase) -> None:
"""
Do some early sanity checks of a received event
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 5dc76ef588..5737f8014d 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -32,6 +32,7 @@ from synapse.storage.roommember import RoomsForUser
from synapse.streams.config import PaginationConfig
from synapse.types import (
JsonDict,
+ JsonMapping,
Requester,
RoomStreamToken,
StreamKeyType,
@@ -454,7 +455,7 @@ class InitialSyncHandler:
for s in states
]
- async def get_receipts() -> List[JsonDict]:
+ async def get_receipts() -> List[JsonMapping]:
receipts = await self.store.get_linearized_receipts_for_room(
room_id, to_key=now_token.receipt_key
)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index c036578a3d..44dbbf81dd 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1730,6 +1730,11 @@ class EventCreationHandler:
event.event_id, event.room_id
)
+ if event.type == EventTypes.ServerACL:
+ self._storage_controllers.state.get_server_acl_for_room.invalidate(
+ (event.room_id,)
+ )
+
await self._maybe_kick_guest_users(event, context)
if event.type == EventTypes.CanonicalAlias:
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 375c7d0901..7c7cda3e95 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -401,9 +401,9 @@ class BasePresenceHandler(abc.ABC):
states,
)
- for destination, host_states in hosts_to_states.items():
+ for destinations, host_states in hosts_to_states:
await self._federation.send_presence_to_destinations(
- host_states, [destination]
+ host_states, destinations
)
async def send_full_presence_to_users(self, user_ids: StrCollection) -> None:
@@ -1000,9 +1000,9 @@ class PresenceHandler(BasePresenceHandler):
list(to_federation_ping.values()),
)
- for destination, states in hosts_to_states.items():
+ for destinations, states in hosts_to_states:
await self._federation_queue.send_presence_to_destinations(
- states, [destination]
+ states, destinations
)
@wrap_as_background_process("handle_presence_timeouts")
@@ -2276,7 +2276,7 @@ async def get_interested_remotes(
store: DataStore,
presence_router: PresenceRouter,
states: List[UserPresenceState],
-) -> Dict[str, Set[UserPresenceState]]:
+) -> List[Tuple[StrCollection, Collection[UserPresenceState]]]:
"""Given a list of presence states figure out which remote servers
should be sent which.
@@ -2290,23 +2290,26 @@ async def get_interested_remotes(
Returns:
A map from destinations to presence states to send to that destination.
"""
- hosts_and_states: Dict[str, Set[UserPresenceState]] = {}
+ hosts_and_states: List[Tuple[StrCollection, Collection[UserPresenceState]]] = []
# First we look up the rooms each user is in (as well as any explicit
# subscriptions), then for each distinct room we look up the remote
# hosts in those rooms.
- room_ids_to_states, users_to_states = await get_interested_parties(
- store, presence_router, states
- )
+ for state in states:
+ room_ids = await store.get_rooms_for_user(state.user_id)
+ hosts: Set[str] = set()
+ for room_id in room_ids:
+ room_hosts = await store.get_current_hosts_in_room(room_id)
+ hosts.update(room_hosts)
+ hosts_and_states.append((hosts, [state]))
- for room_id, states in room_ids_to_states.items():
- hosts = await store.get_current_hosts_in_room(room_id)
- for host in hosts:
- hosts_and_states.setdefault(host, set()).update(states)
+ # Ask a presence routing module for any additional parties if one
+ # is loaded.
+ router_users_to_states = await presence_router.get_users_for_states(states)
- for user_id, states in users_to_states.items():
+ for user_id, user_states in router_users_to_states.items():
host = get_domain_from_id(user_id)
- hosts_and_states.setdefault(host, set()).update(states)
+ hosts_and_states.append(([host], user_states))
return hosts_and_states
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index c7edada353..a7a29b758b 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -19,6 +19,7 @@ from synapse.appservice import ApplicationService
from synapse.streams import EventSource
from synapse.types import (
JsonDict,
+ JsonMapping,
ReadReceipt,
StreamKeyType,
UserID,
@@ -204,15 +205,15 @@ class ReceiptsHandler:
await self.federation_sender.send_read_receipt(receipt)
-class ReceiptEventSource(EventSource[int, JsonDict]):
+class ReceiptEventSource(EventSource[int, JsonMapping]):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.config = hs.config
@staticmethod
def filter_out_private_receipts(
- rooms: Sequence[JsonDict], user_id: str
- ) -> List[JsonDict]:
+ rooms: Sequence[JsonMapping], user_id: str
+ ) -> List[JsonMapping]:
"""
Filters a list of serialized receipts (as returned by /sync and /initialSync)
and removes private read receipts of other users.
@@ -229,7 +230,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
The same as rooms, but filtered.
"""
- result = []
+ result: List[JsonMapping] = []
# Iterate through each room's receipt content.
for room in rooms:
@@ -282,7 +283,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
room_ids: Iterable[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
- ) -> Tuple[List[JsonDict], int]:
+ ) -> Tuple[List[JsonMapping], int]:
from_key = int(from_key)
to_key = self.get_current_key()
@@ -301,7 +302,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
async def get_new_events_as(
self, from_key: int, to_key: int, service: ApplicationService
- ) -> Tuple[List[JsonDict], int]:
+ ) -> Tuple[List[JsonMapping], int]:
"""Returns a set of new read receipt events that an appservice
may be interested in.
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index db97f7aede..9b13448cdd 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -13,7 +13,17 @@
# limitations under the License.
import enum
import logging
-from typing import TYPE_CHECKING, Collection, Dict, FrozenSet, Iterable, List, Optional
+from typing import (
+ TYPE_CHECKING,
+ Collection,
+ Dict,
+ FrozenSet,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Sequence,
+)
import attr
@@ -245,7 +255,7 @@ class RelationsHandler:
async def get_references_for_events(
self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset()
- ) -> Dict[str, List[_RelatedEvent]]:
+ ) -> Mapping[str, Sequence[_RelatedEvent]]:
"""Get a list of references to the given events.
Args:
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 1a4d394eda..7bd42f635f 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -235,7 +235,7 @@ class SyncResult:
archived: List[ArchivedSyncResult]
to_device: List[JsonDict]
device_lists: DeviceListUpdates
- device_one_time_keys_count: JsonDict
+ device_one_time_keys_count: JsonMapping
device_unused_fallback_key_types: List[str]
def __bool__(self) -> bool:
@@ -1558,7 +1558,7 @@ class SyncHandler:
logger.debug("Fetching OTK data")
device_id = sync_config.device_id
- one_time_keys_count: JsonDict = {}
+ one_time_keys_count: JsonMapping = {}
unused_fallback_key_types: List[str] = []
if device_id:
# TODO: We should have a way to let clients differentiate between the states of:
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 4b4227003d..bdefa7f26f 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -26,7 +26,14 @@ from synapse.metrics.background_process_metrics import (
)
from synapse.replication.tcp.streams import TypingStream
from synapse.streams import EventSource
-from synapse.types import JsonDict, Requester, StrCollection, StreamKeyType, UserID
+from synapse.types import (
+ JsonDict,
+ JsonMapping,
+ Requester,
+ StrCollection,
+ StreamKeyType,
+ UserID,
+)
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.metrics import Measure
from synapse.util.retryutils import filter_destinations_by_retry_limiter
@@ -487,7 +494,7 @@ class TypingWriterHandler(FollowerTypingHandler):
raise Exception("Typing writer instance got typing info over replication")
-class TypingNotificationEventSource(EventSource[int, JsonDict]):
+class TypingNotificationEventSource(EventSource[int, JsonMapping]):
def __init__(self, hs: "HomeServer"):
self._main_store = hs.get_datastores().main
self.clock = hs.get_clock()
@@ -497,7 +504,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]):
#
self.get_typing_handler = hs.get_typing_handler
- def _make_event_for(self, room_id: str) -> JsonDict:
+ def _make_event_for(self, room_id: str) -> JsonMapping:
typing = self.get_typing_handler()._room_typing[room_id]
return {
"type": EduTypes.TYPING,
@@ -507,7 +514,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]):
async def get_new_events_as(
self, from_key: int, service: ApplicationService
- ) -> Tuple[List[JsonDict], int]:
+ ) -> Tuple[List[JsonMapping], int]:
"""Returns a set of new typing events that an appservice
may be interested in.
@@ -551,7 +558,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]):
room_ids: Iterable[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
- ) -> Tuple[List[JsonDict], int]:
+ ) -> Tuple[List[JsonMapping], int]:
with Measure(self.clock, "typing.get_new_events"):
from_key = int(from_key)
handler = self.get_typing_handler()
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 5d79d31579..d9d5655c95 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -28,8 +28,15 @@ from typing import (
overload,
)
-from pydantic import BaseModel, MissingError, PydanticValueError, ValidationError
-from pydantic.error_wrappers import ErrorWrapper
+from synapse._pydantic_compat import HAS_PYDANTIC_V2
+
+if TYPE_CHECKING or HAS_PYDANTIC_V2:
+ from pydantic.v1 import BaseModel, MissingError, PydanticValueError, ValidationError
+ from pydantic.v1.error_wrappers import ErrorWrapper
+else:
+ from pydantic import BaseModel, MissingError, PydanticValueError, ValidationError
+ from pydantic.error_wrappers import ErrorWrapper
+
from typing_extensions import Literal
from twisted.web.server import Request
diff --git a/synapse/media/_base.py b/synapse/media/_base.py
index 20cb8b9010..80c448de2b 100644
--- a/synapse/media/_base.py
+++ b/synapse/media/_base.py
@@ -50,6 +50,39 @@ TEXT_CONTENT_TYPES = [
"text/xml",
]
+# A list of all content types that are "safe" to be rendered inline in a browser.
+INLINE_CONTENT_TYPES = [
+ "text/css",
+ "text/plain",
+ "text/csv",
+ "application/json",
+ "application/ld+json",
+ # We allow some media files deemed as safe, which comes from the matrix-react-sdk.
+ # https://github.com/matrix-org/matrix-react-sdk/blob/a70fcfd0bcf7f8c85986da18001ea11597989a7c/src/utils/blobs.ts#L51
+ # SVGs are *intentionally* omitted.
+ "image/jpeg",
+ "image/gif",
+ "image/png",
+ "image/apng",
+ "image/webp",
+ "image/avif",
+ "video/mp4",
+ "video/webm",
+ "video/ogg",
+ "video/quicktime",
+ "audio/mp4",
+ "audio/webm",
+ "audio/aac",
+ "audio/mpeg",
+ "audio/ogg",
+ "audio/wave",
+ "audio/wav",
+ "audio/x-wav",
+ "audio/x-pn-wav",
+ "audio/flac",
+ "audio/x-flac",
+]
+
def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
"""Parses the server name, media ID and optional file name from the request URI
@@ -153,8 +186,13 @@ def add_file_headers(
request.setHeader(b"Content-Type", content_type.encode("UTF-8"))
- # Use a Content-Disposition of attachment to force download of media.
- disposition = "attachment"
+ # A strict subset of content types is allowed to be inlined so that they may
+ # be viewed directly in a browser. Other file types are forced to be downloads.
+ if media_type.lower() in INLINE_CONTENT_TYPES:
+ disposition = "inline"
+ else:
+ disposition = "attachment"
+
if upload_name:
# RFC6266 section 4.1 [1] defines both `filename` and `filename*`.
#
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 554634579e..14784312dc 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -131,7 +131,7 @@ class BulkPushRuleEvaluator:
async def _get_rules_for_event(
self,
event: EventBase,
- ) -> Dict[str, FilteredPushRules]:
+ ) -> Mapping[str, FilteredPushRules]:
"""Get the push rules for all users who may need to be notified about
the event.
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index ca8a76f77c..f4f2b29e96 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -205,6 +205,12 @@ class ReplicationDataHandler:
self.notifier.notify_user_joined_room(
row.data.event_id, row.data.room_id
)
+
+ # If this is a server ACL event, clear the cache in the storage controller.
+ if row.data.type == EventTypes.ServerACL:
+ self._state_storage_controller.get_server_acl_for_room.invalidate(
+ (row.data.room_id,)
+ )
elif stream_name == UnPartialStatedRoomStream.NAME:
for row in rows:
assert isinstance(row, UnPartialStatedRoomStreamRow)
@@ -333,7 +339,7 @@ class ReplicationDataHandler:
try:
await make_deferred_yieldable(deferred)
except defer.TimeoutError:
- logger.error(
+ logger.warning(
"Timed out waiting for repl stream %r to reach %s (%s)"
"; currently at: %s",
stream_name,
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 7d0b4b55a0..e42dade246 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -16,7 +16,6 @@
# limitations under the License.
import logging
-import platform
from http import HTTPStatus
from typing import TYPE_CHECKING, Optional, Tuple
@@ -107,10 +106,7 @@ class VersionServlet(RestServlet):
PATTERNS = admin_patterns("/server_version$")
def __init__(self, hs: "HomeServer"):
- self.res = {
- "server_version": SYNAPSE_VERSION,
- "python_version": platform.python_version(),
- }
+ self.res = {"server_version": SYNAPSE_VERSION}
def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
return HTTPStatus.OK, self.res
diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index 49cd0805fd..e74a87af4d 100644
--- a/synapse/rest/client/account.py
+++ b/synapse/rest/client/account.py
@@ -18,7 +18,12 @@ import random
from typing import TYPE_CHECKING, List, Optional, Tuple
from urllib.parse import urlparse
-from pydantic import StrictBool, StrictStr, constr
+from synapse._pydantic_compat import HAS_PYDANTIC_V2
+
+if TYPE_CHECKING or HAS_PYDANTIC_V2:
+ from pydantic.v1 import StrictBool, StrictStr, constr
+else:
+ from pydantic import StrictBool, StrictStr, constr
from typing_extensions import Literal
from twisted.web.server import Request
diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py
index 925f037743..80ae937921 100644
--- a/synapse/rest/client/devices.py
+++ b/synapse/rest/client/devices.py
@@ -17,7 +17,12 @@ import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, List, Optional, Tuple
-from pydantic import Extra, StrictStr
+from synapse._pydantic_compat import HAS_PYDANTIC_V2
+
+if TYPE_CHECKING or HAS_PYDANTIC_V2:
+ from pydantic.v1 import Extra, StrictStr
+else:
+ from pydantic import Extra, StrictStr
from synapse.api import errors
from synapse.api.errors import NotFoundError, SynapseError, UnrecognizedRequestError
diff --git a/synapse/rest/client/directory.py b/synapse/rest/client/directory.py
index 570bb52747..82944ca711 100644
--- a/synapse/rest/client/directory.py
+++ b/synapse/rest/client/directory.py
@@ -15,7 +15,13 @@
import logging
from typing import TYPE_CHECKING, List, Optional, Tuple
-from pydantic import StrictStr
+from synapse._pydantic_compat import HAS_PYDANTIC_V2
+
+if TYPE_CHECKING or HAS_PYDANTIC_V2:
+ from pydantic.v1 import StrictStr
+else:
+ from pydantic import StrictStr
+
from typing_extensions import Literal
from twisted.web.server import Request
diff --git a/synapse/rest/client/filter.py b/synapse/rest/client/filter.py
index 5da1e511a2..b5879496db 100644
--- a/synapse/rest/client/filter.py
+++ b/synapse/rest/client/filter.py
@@ -19,7 +19,7 @@ from synapse.api.errors import AuthError, NotFoundError, StoreError, SynapseErro
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
-from synapse.types import JsonDict, UserID
+from synapse.types import JsonDict, JsonMapping, UserID
from ._base import client_patterns, set_timeline_upper_limit
@@ -41,7 +41,7 @@ class GetFilterRestServlet(RestServlet):
async def on_GET(
self, request: SynapseRequest, user_id: str, filter_id: str
- ) -> Tuple[int, JsonDict]:
+ ) -> Tuple[int, JsonMapping]:
target_user = UserID.from_string(user_id)
requester = await self.auth.get_user_by_req(request)
diff --git a/synapse/rest/client/models.py b/synapse/rest/client/models.py
index 3d7940b0fc..880f79473c 100644
--- a/synapse/rest/client/models.py
+++ b/synapse/rest/client/models.py
@@ -13,7 +13,12 @@
# limitations under the License.
from typing import TYPE_CHECKING, Dict, Optional
-from pydantic import Extra, StrictInt, StrictStr, constr, validator
+from synapse._pydantic_compat import HAS_PYDANTIC_V2
+
+if TYPE_CHECKING or HAS_PYDANTIC_V2:
+ from pydantic.v1 import Extra, StrictInt, StrictStr, constr, validator
+else:
+ from pydantic import Extra, StrictInt, StrictStr, constr, validator
from synapse.rest.models import RequestBodyModel
from synapse.util.threepids import validate_email
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 0aaa838d04..48c47058db 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -16,7 +16,13 @@ import logging
import re
from typing import TYPE_CHECKING, Dict, Mapping, Optional, Set, Tuple
-from pydantic import Extra, StrictInt, StrictStr
+from synapse._pydantic_compat import HAS_PYDANTIC_V2
+
+if TYPE_CHECKING or HAS_PYDANTIC_V2:
+ from pydantic.v1 import Extra, StrictInt, StrictStr
+else:
+ from pydantic import StrictInt, StrictStr, Extra
+
from signedjson.sign import sign_json
from twisted.web.server import Request
diff --git a/synapse/rest/models.py b/synapse/rest/models.py
index ac39cda8e5..de354a2135 100644
--- a/synapse/rest/models.py
+++ b/synapse/rest/models.py
@@ -1,4 +1,24 @@
-from pydantic import BaseModel, Extra
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from synapse._pydantic_compat import HAS_PYDANTIC_V2
+
+if TYPE_CHECKING or HAS_PYDANTIC_V2:
+ from pydantic.v1 import BaseModel, Extra
+else:
+ from pydantic import BaseModel, Extra
class RequestBodyModel(BaseModel):
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 1752f95db8..b2e63aed1e 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -23,7 +23,6 @@ from typing import (
Generator,
Iterable,
List,
- Mapping,
Optional,
Sequence,
Set,
@@ -269,7 +268,7 @@ async def _get_power_level_for_sender(
async def _get_auth_chain_difference(
room_id: str,
- state_sets: Sequence[Mapping[Any, str]],
+ state_sets: Sequence[StateMap[str]],
unpersisted_events: Dict[str, EventBase],
state_res_store: StateResolutionStore,
) -> Set[str]:
@@ -405,7 +404,7 @@ def _seperate(
# mypy doesn't understand that discarding None above means that conflicted
# state is StateMap[Set[str]], not StateMap[Set[Optional[Str]]].
- return unconflicted_state, conflicted_state # type: ignore
+ return unconflicted_state, conflicted_state # type: ignore[return-value]
def _is_power_event(event: EventBase) -> bool:
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 99ebd96f84..12829d3d7d 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -31,8 +31,8 @@ from typing import (
)
import attr
-from pydantic import BaseModel
+from synapse._pydantic_compat import HAS_PYDANTIC_V2
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Connection, Cursor
@@ -41,6 +41,11 @@ from synapse.util import Clock, json_encoder
from . import engines
+if TYPE_CHECKING or HAS_PYDANTIC_V2:
+ from pydantic.v1 import BaseModel
+else:
+ from pydantic import BaseModel
+
if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.storage.database import DatabasePool, LoggingTransaction
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 278c7832ba..46957723a1 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -37,6 +37,7 @@ from synapse.storage.util.partial_state_events_tracker import (
PartialCurrentStateTracker,
PartialStateEventsTracker,
)
+from synapse.synapse_rust.acl import ServerAclEvaluator
from synapse.types import MutableStateMap, StateMap, get_domain_from_id
from synapse.types.state import StateFilter
from synapse.util.async_helpers import Linearizer
@@ -501,6 +502,31 @@ class StateStorageController:
return event.content.get("alias")
+ @cached()
+ async def get_server_acl_for_room(
+ self, room_id: str
+ ) -> Optional[ServerAclEvaluator]:
+ """Get the server ACL evaluator for room, if any
+
+ This does up-front parsing of the content to ignore bad data and pre-compile
+ regular expressions.
+
+ Args:
+ room_id: The room ID
+
+ Returns:
+ The server ACL evaluator, if any
+ """
+
+ acl_event = await self.get_current_state_event(
+ room_id, EventTypes.ServerACL, ""
+ )
+
+ if not acl_event:
+ return None
+
+ return server_acl_evaluator_from_event(acl_event)
+
@trace
@tag_args
async def get_current_state_deltas(
@@ -582,7 +608,7 @@ class StateStorageController:
@trace
@tag_args
- async def get_current_hosts_in_room_ordered(self, room_id: str) -> List[str]:
+ async def get_current_hosts_in_room_ordered(self, room_id: str) -> Tuple[str, ...]:
"""Get current hosts in room based on current state.
Blocks until we have full state for the given room. This only happens for rooms
@@ -760,3 +786,36 @@ class StateStorageController:
cache.state_group = object()
return frozenset(cache.hosts_to_joined_users)
+
+
+def server_acl_evaluator_from_event(acl_event: EventBase) -> "ServerAclEvaluator":
+ """
+ Create a ServerAclEvaluator from a m.room.server_acl event's content.
+
+ This does up-front parsing of the content to ignore bad data. It then creates
+ the ServerAclEvaluator which will pre-compile regular expressions from the globs.
+ """
+
+ # first of all, parse if literal IPs are blocked.
+ allow_ip_literals = acl_event.content.get("allow_ip_literals", True)
+ if not isinstance(allow_ip_literals, bool):
+ logger.warning("Ignoring non-bool allow_ip_literals flag")
+ allow_ip_literals = True
+
+ # next, parse the deny list by ignoring any non-strings.
+ deny = acl_event.content.get("deny", [])
+ if not isinstance(deny, (list, tuple)):
+ logger.warning("Ignoring non-list deny ACL %s", deny)
+ deny = []
+ else:
+ deny = [s for s in deny if isinstance(s, str)]
+
+ # then the allow list.
+ allow = acl_event.content.get("allow", [])
+ if not isinstance(allow, (list, tuple)):
+ logger.warning("Ignoring non-list allow ACL %s", allow)
+ allow = []
+ else:
+ allow = [s for s in allow if isinstance(s, str)]
+
+ return ServerAclEvaluator(allow_ip_literals, allow, deny)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 697bc5651c..ca894edd5a 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -361,19 +361,7 @@ class LoggingTransaction:
@property
def description(
self,
- ) -> Optional[
- Sequence[
- Tuple[
- str,
- Optional[Any],
- Optional[int],
- Optional[int],
- Optional[int],
- Optional[int],
- Optional[int],
- ]
- ]
- ]:
+ ) -> Optional[Sequence[Any]]:
return self.txn.description
def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 484db175d0..0553a0621a 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -45,7 +45,7 @@ from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import build_sequence_generator
-from synapse.types import DeviceListUpdates, JsonDict
+from synapse.types import DeviceListUpdates, JsonMapping
from synapse.util import json_encoder
from synapse.util.caches.descriptors import _CacheContext, cached
@@ -268,8 +268,8 @@ class ApplicationServiceTransactionWorkerStore(
self,
service: ApplicationService,
events: Sequence[EventBase],
- ephemeral: List[JsonDict],
- to_device_messages: List[JsonDict],
+ ephemeral: List[JsonMapping],
+ to_device_messages: List[JsonMapping],
one_time_keys_count: TransactionOneTimeKeysCount,
unused_fallback_keys: TransactionUnusedFallbackKeys,
device_list_summary: DeviceListUpdates,
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 70faf4b1ec..df596f35f9 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -55,7 +55,12 @@ from synapse.storage.util.id_generators import (
AbstractStreamIdGenerator,
StreamIdGenerator,
)
-from synapse.types import JsonDict, StrCollection, get_verify_key_from_cross_signing_key
+from synapse.types import (
+ JsonDict,
+ JsonMapping,
+ StrCollection,
+ get_verify_key_from_cross_signing_key,
+)
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache
@@ -746,7 +751,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
@cancellable
async def get_user_devices_from_cache(
self, user_ids: Set[str], user_and_device_ids: List[Tuple[str, str]]
- ) -> Tuple[Set[str], Dict[str, Mapping[str, JsonDict]]]:
+ ) -> Tuple[Set[str], Dict[str, Mapping[str, JsonMapping]]]:
"""Get the devices (and keys if any) for remote users from the cache.
Args:
@@ -766,13 +771,13 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
user_ids_not_in_cache = unique_user_ids - user_ids_in_cache
# First fetch all the users which all devices are to be returned.
- results: Dict[str, Mapping[str, JsonDict]] = {}
+ results: Dict[str, Mapping[str, JsonMapping]] = {}
for user_id in user_ids:
if user_id in user_ids_in_cache:
results[user_id] = await self.get_cached_devices_for_user(user_id)
# Then fetch all device-specific requests, but skip users we've already
# fetched all devices for.
- device_specific_results: Dict[str, Dict[str, JsonDict]] = {}
+ device_specific_results: Dict[str, Dict[str, JsonMapping]] = {}
for user_id, device_id in user_and_device_ids:
if user_id in user_ids_in_cache and user_id not in user_ids:
device = await self._get_cached_user_device(user_id, device_id)
@@ -801,7 +806,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
return user_ids_in_cache
@cached(num_args=2, tree=True)
- async def _get_cached_user_device(self, user_id: str, device_id: str) -> JsonDict:
+ async def _get_cached_user_device(
+ self, user_id: str, device_id: str
+ ) -> JsonMapping:
content = await self.db_pool.simple_select_one_onecol(
table="device_lists_remote_cache",
keyvalues={"user_id": user_id, "device_id": device_id},
@@ -811,7 +818,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
return db_to_json(content)
@cached()
- async def get_cached_devices_for_user(self, user_id: str) -> Mapping[str, JsonDict]:
+ async def get_cached_devices_for_user(
+ self, user_id: str
+ ) -> Mapping[str, JsonMapping]:
devices = await self.db_pool.simple_select_list(
table="device_lists_remote_cache",
keyvalues={"user_id": user_id},
@@ -1042,7 +1051,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
)
async def get_device_list_last_stream_id_for_remotes(
self, user_ids: Iterable[str]
- ) -> Dict[str, Optional[str]]:
+ ) -> Mapping[str, Optional[str]]:
rows = await self.db_pool.simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index b49dea577c..89fac23f93 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -52,7 +52,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import StreamIdGenerator
-from synapse.types import JsonDict
+from synapse.types import JsonDict, JsonMapping
from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.cancellation import cancellable
@@ -125,7 +125,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
async def get_e2e_device_keys_for_federation_query(
self, user_id: str
- ) -> Tuple[int, List[JsonDict]]:
+ ) -> Tuple[int, Sequence[JsonMapping]]:
"""Get all devices (with any device keys) for a user
Returns:
@@ -174,7 +174,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
@cached(iterable=True)
async def _get_e2e_device_keys_for_federation_query_inner(
self, user_id: str
- ) -> List[JsonDict]:
+ ) -> Sequence[JsonMapping]:
"""Get all devices (with any device keys) for a user"""
devices = await self.get_e2e_device_keys_and_signatures([(user_id, None)])
@@ -578,7 +578,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
@cached(max_entries=10000)
async def count_e2e_one_time_keys(
self, user_id: str, device_id: str
- ) -> Dict[str, int]:
+ ) -> Mapping[str, int]:
"""Count the number of one time keys the server has for a device
Returns:
A mapping from algorithm to number of keys for that algorithm.
@@ -812,7 +812,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
async def get_e2e_cross_signing_key(
self, user_id: str, key_type: str, from_user_id: Optional[str] = None
- ) -> Optional[JsonDict]:
+ ) -> Optional[JsonMapping]:
"""Returns a user's cross-signing key.
Args:
@@ -833,7 +833,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
return user_keys.get(key_type)
@cached(num_args=1)
- def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Mapping[str, JsonDict]:
+ def _get_bare_e2e_cross_signing_keys(
+ self, user_id: str
+ ) -> Mapping[str, JsonMapping]:
"""Dummy function. Only used to make a cache for
_get_bare_e2e_cross_signing_keys_bulk.
"""
@@ -846,7 +848,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
)
async def _get_bare_e2e_cross_signing_keys_bulk(
self, user_ids: Iterable[str]
- ) -> Dict[str, Optional[Mapping[str, JsonDict]]]:
+ ) -> Mapping[str, Optional[Mapping[str, JsonMapping]]]:
"""Returns the cross-signing keys for a set of users. The output of this
function should be passed to _get_e2e_cross_signing_signatures_txn if
the signatures for the calling user need to be fetched.
@@ -860,15 +862,12 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
their user ID will map to None.
"""
- result = await self.db_pool.runInteraction(
+ return await self.db_pool.runInteraction(
"get_bare_e2e_cross_signing_keys_bulk",
self._get_bare_e2e_cross_signing_keys_bulk_txn,
user_ids,
)
- # The `Optional` comes from the `@cachedList` decorator.
- return cast(Dict[str, Optional[Mapping[str, JsonDict]]], result)
-
def _get_bare_e2e_cross_signing_keys_bulk_txn(
self,
txn: LoggingTransaction,
@@ -1026,7 +1025,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
@cancellable
async def get_e2e_cross_signing_keys_bulk(
self, user_ids: List[str], from_user_id: Optional[str] = None
- ) -> Dict[str, Optional[Mapping[str, JsonDict]]]:
+ ) -> Mapping[str, Optional[Mapping[str, JsonMapping]]]:
"""Returns the cross-signing keys for a set of users.
Args:
@@ -1043,7 +1042,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
if from_user_id:
result = cast(
- Dict[str, Optional[Mapping[str, JsonDict]]],
+ Dict[str, Optional[Mapping[str, JsonMapping]]],
await self.db_pool.runInteraction(
"get_e2e_cross_signing_signatures",
self._get_e2e_cross_signing_signatures_txn,
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 943666ed4f..8737a1370e 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -24,6 +24,7 @@ from typing import (
Dict,
Iterable,
List,
+ Mapping,
MutableMapping,
Optional,
Set,
@@ -1633,7 +1634,7 @@ class EventsWorkerStore(SQLBaseStore):
self,
room_id: str,
event_ids: Collection[str],
- ) -> Dict[str, bool]:
+ ) -> Mapping[str, bool]:
"""Helper for have_seen_events
Returns:
@@ -2329,7 +2330,7 @@ class EventsWorkerStore(SQLBaseStore):
@cachedList(cached_method_name="is_partial_state_event", list_name="event_ids")
async def get_partial_state_events(
self, event_ids: Collection[str]
- ) -> Dict[str, bool]:
+ ) -> Mapping[str, bool]:
"""Checks which of the given events have partial state
Args:
diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index 047de6283a..7d94685caf 100644
--- a/synapse/storage/databases/main/filtering.py
+++ b/synapse/storage/databases/main/filtering.py
@@ -25,7 +25,7 @@ from synapse.storage.database import (
LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
-from synapse.types import JsonDict, UserID
+from synapse.types import JsonDict, JsonMapping, UserID
from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
@@ -145,7 +145,7 @@ class FilteringWorkerStore(SQLBaseStore):
@cached(num_args=2)
async def get_user_filter(
self, user_id: UserID, filter_id: Union[int, str]
- ) -> JsonDict:
+ ) -> JsonMapping:
# filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
# with a coherent error message rather than 500 M_UNKNOWN.
try:
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index 41563371dc..889c578b9c 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -16,7 +16,7 @@
import itertools
import json
import logging
-from typing import Dict, Iterable, Optional, Tuple
+from typing import Dict, Iterable, Mapping, Optional, Tuple
from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
@@ -130,7 +130,7 @@ class KeyStore(CacheInvalidationWorkerStore):
)
async def get_server_keys_json(
self, server_name_and_key_ids: Iterable[Tuple[str, str]]
- ) -> Dict[Tuple[str, str], FetchKeyResult]:
+ ) -> Mapping[Tuple[str, str], FetchKeyResult]:
"""
Args:
server_name_and_key_ids:
@@ -200,7 +200,7 @@ class KeyStore(CacheInvalidationWorkerStore):
)
async def get_server_keys_json_for_remote(
self, server_name: str, key_ids: Iterable[str]
- ) -> Dict[str, Optional[FetchKeyResultForRemote]]:
+ ) -> Mapping[str, Optional[FetchKeyResultForRemote]]:
"""Fetch the cached keys for the given server/key IDs.
If we have multiple entries for a given key ID, returns the most recent.
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index b51d20ac26..194b4e031f 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -11,7 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, cast
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Tuple,
+ cast,
+)
from synapse.api.presence import PresenceState, UserPresenceState
from synapse.replication.tcp.streams import PresenceStream
@@ -249,7 +259,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
)
async def get_presence_for_users(
self, user_ids: Iterable[str]
- ) -> Dict[str, UserPresenceState]:
+ ) -> Mapping[str, UserPresenceState]:
rows = await self.db_pool.simple_select_many_batch(
table="presence_stream",
column="user_id",
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index bec0dc2afe..923166974c 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -88,6 +88,7 @@ def _load_rules(
msc1767_enabled=experimental_config.msc1767_enabled,
msc3664_enabled=experimental_config.msc3664_enabled,
msc3381_polls_enabled=experimental_config.msc3381_polls_enabled,
+ msc4028_push_encrypted_events=experimental_config.msc4028_push_encrypted_events,
)
return filtered_rules
@@ -216,7 +217,7 @@ class PushRulesWorkerStore(
@cachedList(cached_method_name="get_push_rules_for_user", list_name="user_ids")
async def bulk_get_push_rules(
self, user_ids: Collection[str]
- ) -> Dict[str, FilteredPushRules]:
+ ) -> Mapping[str, FilteredPushRules]:
if not user_ids:
return {}
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index a074c43989..0231f9407b 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -43,7 +43,7 @@ from synapse.storage.util.id_generators import (
MultiWriterIdGenerator,
StreamIdGenerator,
)
-from synapse.types import JsonDict
+from synapse.types import JsonDict, JsonMapping
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -218,7 +218,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cached()
async def _get_receipts_for_user_with_orderings(
self, user_id: str, receipt_type: str
- ) -> JsonDict:
+ ) -> JsonMapping:
"""
Fetch receipts for all rooms that the given user is joined to.
@@ -258,7 +258,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
async def get_linearized_receipts_for_rooms(
self, room_ids: Iterable[str], to_key: int, from_key: Optional[int] = None
- ) -> List[dict]:
+ ) -> List[JsonMapping]:
"""Get receipts for multiple rooms for sending to clients.
Args:
@@ -287,7 +287,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
async def get_linearized_receipts_for_room(
self, room_id: str, to_key: int, from_key: Optional[int] = None
- ) -> Sequence[JsonDict]:
+ ) -> Sequence[JsonMapping]:
"""Get receipts for a single room for sending to clients.
Args:
@@ -310,7 +310,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cached(tree=True)
async def _get_linearized_receipts_for_room(
self, room_id: str, to_key: int, from_key: Optional[int] = None
- ) -> Sequence[JsonDict]:
+ ) -> Sequence[JsonMapping]:
"""See get_linearized_receipts_for_room"""
def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
@@ -353,7 +353,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
async def _get_linearized_receipts_for_rooms(
self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
- ) -> Dict[str, Sequence[JsonDict]]:
+ ) -> Mapping[str, Sequence[JsonMapping]]:
if not room_ids:
return {}
@@ -415,7 +415,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
async def get_linearized_receipts_for_all_rooms(
self, to_key: int, from_key: Optional[int] = None
- ) -> Mapping[str, JsonDict]:
+ ) -> Mapping[str, JsonMapping]:
"""Get receipts for all rooms between two stream_ids, up
to a limit of the latest 100 read receipts.
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 96908f14ba..b67f780c10 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -465,7 +465,7 @@ class RelationsWorkerStore(SQLBaseStore):
@cachedList(cached_method_name="get_references_for_event", list_name="event_ids")
async def get_references_for_events(
self, event_ids: Collection[str]
- ) -> Mapping[str, Optional[List[_RelatedEvent]]]:
+ ) -> Mapping[str, Optional[Sequence[_RelatedEvent]]]:
"""Get a list of references to the given events.
Args:
@@ -519,7 +519,7 @@ class RelationsWorkerStore(SQLBaseStore):
@cachedList(cached_method_name="get_applicable_edit", list_name="event_ids")
async def get_applicable_edits(
self, event_ids: Collection[str]
- ) -> Dict[str, Optional[EventBase]]:
+ ) -> Mapping[str, Optional[EventBase]]:
"""Get the most recent edit (if any) that has happened for the given
events.
@@ -605,7 +605,7 @@ class RelationsWorkerStore(SQLBaseStore):
@cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
async def get_thread_summaries(
self, event_ids: Collection[str]
- ) -> Dict[str, Optional[Tuple[int, EventBase]]]:
+ ) -> Mapping[str, Optional[Tuple[int, EventBase]]]:
"""Get the number of threaded replies and the latest reply (if any) for the given events.
Args:
@@ -779,7 +779,7 @@ class RelationsWorkerStore(SQLBaseStore):
@cachedList(cached_method_name="get_thread_participated", list_name="event_ids")
async def get_threads_participated(
self, event_ids: Collection[str], user_id: str
- ) -> Dict[str, bool]:
+ ) -> Mapping[str, bool]:
"""Get whether the requesting user participated in the given threads.
This is separate from get_thread_summaries since that can be cached across
@@ -931,7 +931,7 @@ class RelationsWorkerStore(SQLBaseStore):
room_id: str,
limit: int = 5,
from_token: Optional[ThreadsNextBatch] = None,
- ) -> Tuple[List[str], Optional[ThreadsNextBatch]]:
+ ) -> Tuple[Sequence[str], Optional[ThreadsNextBatch]]:
"""Get a list of thread IDs, ordered by topological ordering of their
latest reply.
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index fff259f74c..3755773faa 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -191,7 +191,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
)
async def get_subset_users_in_room_with_profiles(
self, room_id: str, user_ids: Collection[str]
- ) -> Dict[str, ProfileInfo]:
+ ) -> Mapping[str, ProfileInfo]:
"""Get a mapping from user ID to profile information for a list of users
in a given room.
@@ -676,7 +676,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
)
async def _get_rooms_for_users(
self, user_ids: Collection[str]
- ) -> Dict[str, FrozenSet[str]]:
+ ) -> Mapping[str, FrozenSet[str]]:
"""A batched version of `get_rooms_for_user`.
Returns:
@@ -881,7 +881,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
)
async def _get_user_ids_from_membership_event_ids(
self, event_ids: Iterable[str]
- ) -> Dict[str, Optional[str]]:
+ ) -> Mapping[str, Optional[str]]:
"""For given set of member event_ids check if they point to a join
event.
@@ -984,7 +984,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
)
@cached(iterable=True, max_entries=10000)
- async def get_current_hosts_in_room_ordered(self, room_id: str) -> List[str]:
+ async def get_current_hosts_in_room_ordered(self, room_id: str) -> Tuple[str, ...]:
"""
Get current hosts in room based on current state.
@@ -1013,12 +1013,14 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
# `get_users_in_room` rather than funky SQL.
domains = await self.get_current_hosts_in_room(room_id)
- return list(domains)
+ return tuple(domains)
# For PostgreSQL we can use a regex to pull out the domains from the
# joined users in `current_state_events` via regex.
- def get_current_hosts_in_room_ordered_txn(txn: LoggingTransaction) -> List[str]:
+ def get_current_hosts_in_room_ordered_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[str, ...]:
# Returns a list of servers currently joined in the room sorted by
# longest in the room first (aka. with the lowest depth). The
# heuristic of sorting by servers who have been in the room the
@@ -1043,7 +1045,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
"""
txn.execute(sql, (room_id,))
# `server_domain` will be `NULL` for malformed MXIDs with no colons.
- return [d for d, in txn if d is not None]
+ return tuple(d for d, in txn if d is not None)
return await self.db_pool.runInteraction(
"get_current_hosts_in_room_ordered", get_current_hosts_in_room_ordered_txn
@@ -1191,7 +1193,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
)
async def get_membership_from_event_ids(
self, member_event_ids: Iterable[str]
- ) -> Dict[str, Optional[EventIdMembership]]:
+ ) -> Mapping[str, Optional[EventIdMembership]]:
"""Get user_id and membership of a set of event IDs.
Returns:
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index ebb2ae964f..5eaaff5b68 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -14,7 +14,17 @@
# limitations under the License.
import collections.abc
import logging
-from typing import TYPE_CHECKING, Any, Collection, Dict, Iterable, Optional, Set, Tuple
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Collection,
+ Dict,
+ Iterable,
+ Mapping,
+ Optional,
+ Set,
+ Tuple,
+)
import attr
@@ -372,7 +382,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
)
async def _get_state_group_for_events(
self, event_ids: Collection[str]
- ) -> Dict[str, int]:
+ ) -> Mapping[str, int]:
"""Returns mapping event_id -> state_group.
Raises:
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index efd21b5bfc..8f70eff809 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -14,7 +14,7 @@
import logging
from enum import Enum
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Iterable, List, Mapping, Optional, Tuple, cast
import attr
from canonicaljson import encode_canonical_json
@@ -210,7 +210,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
)
async def get_destination_retry_timings_batch(
self, destinations: StrCollection
- ) -> Dict[str, Optional[DestinationRetryTimings]]:
+ ) -> Mapping[str, Optional[DestinationRetryTimings]]:
rows = await self.db_pool.simple_select_many_batch(
table="destinations",
iterable=destinations,
diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index f79006533f..06fcbe5e54 100644
--- a/synapse/storage/databases/main/user_erasure_store.py
+++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, Iterable
+from typing import Iterable, Mapping
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main import CacheInvalidationWorkerStore
@@ -40,7 +40,7 @@ class UserErasureWorkerStore(CacheInvalidationWorkerStore):
return bool(result)
@cachedList(cached_method_name="is_user_erased", list_name="user_ids")
- async def are_users_erased(self, user_ids: Iterable[str]) -> Dict[str, bool]:
+ async def are_users_erased(self, user_ids: Iterable[str]) -> Mapping[str, bool]:
"""
Checks which users in a list have requested erasure
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index 5b8ba436d4..6ff533a129 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -94,6 +94,18 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
groups: List[int],
state_filter: Optional[StateFilter] = None,
) -> Mapping[int, StateMap[str]]:
+ """
+ Given a number of state groups, fetch the latest state for each group.
+
+ Args:
+ txn: The transaction object.
+ groups: The given state groups that you want to fetch the latest state for.
+ state_filter: The state filter to apply the state we fetch state from the database.
+
+ Returns:
+ Map from state_group to a StateMap at that point.
+ """
+
state_filter = state_filter or StateFilter.all()
results: Dict[int, MutableStateMap[str]] = {group: {} for group in groups}
@@ -206,8 +218,10 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
if where_clause:
where_clause = " AND (%s)" % (where_clause,)
- # We don't use WITH RECURSIVE on sqlite3 as there are distributions
- # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
+ # XXX: We could `WITH RECURSIVE` here since it's supported on SQLite 3.8.3
+ # or higher and our minimum supported version is greater than that.
+ #
+ # We just haven't put in the time to refactor this.
for group in groups:
next_group: Optional[int] = group
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index 34ac807530..afaeef9a5a 100644
--- a/synapse/storage/types.py
+++ b/synapse/storage/types.py
@@ -53,22 +53,10 @@ class Cursor(Protocol):
@property
def description(
self,
- ) -> Optional[
- Sequence[
- # Note that this is an approximate typing based on sqlite3 and other
- # drivers, and may not be entirely accurate.
- # FWIW, the DBAPI 2 spec is: https://peps.python.org/pep-0249/#description
- Tuple[
- str,
- Optional[Any],
- Optional[int],
- Optional[int],
- Optional[int],
- Optional[int],
- Optional[int],
- ]
- ]
- ]:
+ ) -> Optional[Sequence[Any]]:
+ # At the time of writing, Synapse only assumes that `column[0]: str` for each
+ # `column in description`. Since this is hard to express in the type system, and
+ # as this is rarely used in Synapse, we deem `column: Any` good enough.
...
@property
|