diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 3d7f986ac7..66e869bc2d 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -32,7 +32,6 @@ from synapse.appservice import ApplicationService
from synapse.http import get_request_user_agent
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import (
- SynapseTags,
active_span,
force_tracing,
start_active_span,
@@ -162,12 +161,6 @@ class Auth:
parent_span.set_tag(
"authenticated_entity", requester.authenticated_entity
)
- # We tag the Synapse instance name so that it's an easy jumping
- # off point into the logs. Can also be used to filter for an
- # instance that is under load.
- parent_span.set_tag(
- SynapseTags.INSTANCE_NAME, self.hs.get_instance_name()
- )
parent_span.set_tag("user_id", requester.user.to_string())
if requester.device_id is not None:
parent_span.set_tag("device_id", requester.device_id)
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index c2c177fd71..9235ce6536 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -751,3 +751,25 @@ class ModuleFailedException(Exception):
Raised when a module API callback fails, for example because it raised an
exception.
"""
+
+
+class PartialStateConflictError(SynapseError):
+ """An internal error raised when attempting to persist an event with partial state
+ after the room containing the event has been un-partial stated.
+
+ This error should be handled by recomputing the event context and trying again.
+
+ This error has an HTTP status code so that it can be transported over replication.
+ It should not be exposed to clients.
+ """
+
+ @staticmethod
+ def message() -> str:
+ return "Cannot persist partial state event in un-partial stated room"
+
+ def __init__(self) -> None:
+ super().__init__(
+ HTTPStatus.CONFLICT,
+ msg=PartialStateConflictError.message(),
+ errcode=Codes.UNKNOWN,
+ )
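The docstring above prescribes a recompute-and-retry strategy for this error. A minimal, self-contained sketch of that pattern, assuming hypothetical `compute_event_context` / `persist_event` callables (not real Synapse APIs):

```python
class PartialStateConflictError(Exception):
    """Stand-in for the exception defined above."""


async def persist_with_retry(event, compute_event_context, persist_event, attempts=3):
    # Each conflict means the room was un-partial-stated between computing
    # the context and writing the event, so recompute and try again.
    for _ in range(attempts):
        context = await compute_event_context(event)
        try:
            return await persist_event(event, context)
        except PartialStateConflictError:
            continue
    raise RuntimeError("gave up after repeated partial-state conflicts")
```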
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 83c42fc25a..b9f432cc23 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -219,9 +219,13 @@ class FilterCollection:
self._room_timeline_filter = Filter(hs, room_filter_json.get("timeline", {}))
self._room_state_filter = Filter(hs, room_filter_json.get("state", {}))
self._room_ephemeral_filter = Filter(hs, room_filter_json.get("ephemeral", {}))
- self._room_account_data = Filter(hs, room_filter_json.get("account_data", {}))
+ self._room_account_data_filter = Filter(
+ hs, room_filter_json.get("account_data", {})
+ )
self._presence_filter = Filter(hs, filter_json.get("presence", {}))
- self._account_data = Filter(hs, filter_json.get("account_data", {}))
+ self._global_account_data_filter = Filter(
+ hs, filter_json.get("account_data", {})
+ )
self.include_leave = filter_json.get("room", {}).get("include_leave", False)
self.event_fields = filter_json.get("event_fields", [])
@@ -256,8 +260,10 @@ class FilterCollection:
) -> List[UserPresenceState]:
return await self._presence_filter.filter(presence_states)
- async def filter_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]:
- return await self._account_data.filter(events)
+ async def filter_global_account_data(
+ self, events: Iterable[JsonDict]
+ ) -> List[JsonDict]:
+ return await self._global_account_data_filter.filter(events)
async def filter_room_state(self, events: Iterable[EventBase]) -> List[EventBase]:
return await self._room_state_filter.filter(
@@ -279,7 +285,7 @@ class FilterCollection:
async def filter_room_account_data(
self, events: Iterable[JsonDict]
) -> List[JsonDict]:
- return await self._room_account_data.filter(
+ return await self._room_account_data_filter.filter(
await self._room_filter.filter(events)
)
@@ -292,6 +298,13 @@ class FilterCollection:
or self._presence_filter.filters_all_senders()
)
+ def blocks_all_global_account_data(self) -> bool:
+ """True if all global acount data will be filtered out."""
+ return (
+ self._global_account_data_filter.filters_all_types()
+ or self._global_account_data_filter.filters_all_senders()
+ )
+
def blocks_all_room_ephemeral(self) -> bool:
return (
self._room_ephemeral_filter.filters_all_types()
@@ -299,6 +312,13 @@ class FilterCollection:
or self._room_ephemeral_filter.filters_all_rooms()
)
+ def blocks_all_room_account_data(self) -> bool:
+ return (
+ self._room_account_data_filter.filters_all_types()
+ or self._room_account_data_filter.filters_all_senders()
+ or self._room_account_data_filter.filters_all_rooms()
+ )
+
def blocks_all_room_timeline(self) -> bool:
return (
self._room_timeline_filter.filters_all_types()
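The new `blocks_all_*` helpers let sync short-circuit fetches that the filter would discard anyway. A hedged sketch of the intended call pattern (`fetch_global_account_data` is a hypothetical stand-in for the store call):

```python
async def global_account_data_for_sync(filter_collection, fetch_global_account_data):
    # If the filter is guaranteed to drop everything, skip the database
    # round-trip entirely.
    if filter_collection.blocks_all_global_account_data():
        return []
    events = await fetch_global_account_data()
    return await filter_collection.filter_global_account_data(events)
```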
diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py
index 53db1e85b3..897dd3edac 100644
--- a/synapse/app/phone_stats_home.py
+++ b/synapse/app/phone_stats_home.py
@@ -15,7 +15,7 @@ import logging
import math
import resource
import sys
-from typing import TYPE_CHECKING, List, Sized, Tuple
+from typing import TYPE_CHECKING, List, Mapping, Sized, Tuple
from prometheus_client import Gauge
@@ -194,7 +194,7 @@ def start_phone_stats_home(hs: "HomeServer") -> None:
@wrap_as_background_process("generate_monthly_active_users")
async def generate_monthly_active_users() -> None:
current_mau_count = 0
- current_mau_count_by_service = {}
+ current_mau_count_by_service: Mapping[str, int] = {}
reserved_users: Sized = ()
store = hs.get_datastores().main
if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only:
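The switch to `Mapping` matters because the store may hand back a read-only mapping. A minimal sketch of why the annotation type-checks for both the empty default and the store's return value (the function here is illustrative, not the real store method):

```python
from typing import Mapping


def count_by_service() -> Mapping[str, int]:
    # Returning Mapping lets the store share an internal dict without
    # promising callers that they may mutate it.
    return {"appservice-a": 3}


current_mau_count_by_service: Mapping[str, int] = {}
current_mau_count_by_service = count_by_service()  # OK: any dict is a Mapping
total = sum(current_mau_count_by_service.values())
```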
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 53c0682dfd..6ac2f0c10d 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -169,6 +169,16 @@ class ExperimentalConfig(Config):
# MSC3925: do not replace events with their edits
self.msc3925_inhibit_edit = experimental.get("msc3925_inhibit_edit", False)
+ # MSC3758: exact_event_match push rule condition
+ self.msc3758_exact_event_match = experimental.get(
+ "msc3758_exact_event_match", False
+ )
+
+ # MSC3873: Disambiguate event_match keys.
+ self.msc3873_escape_event_match_key = experimental.get(
+ "msc3873_escape_event_match_key", False
+ )
+
# MSC3952: Intentional mentions
self.msc3952_intentional_mentions = experimental.get(
"msc3952_intentional_mentions", False
diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py
index 3ed236217f..8666c22f01 100644
--- a/synapse/config/room_directory.py
+++ b/synapse/config/room_directory.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, List
+from typing import Any, Collection
from matrix_common.regex import glob_to_regex
@@ -70,7 +70,7 @@ class RoomDirectoryConfig(Config):
return False
def is_publishing_room_allowed(
- self, user_id: str, room_id: str, aliases: List[str]
+ self, user_id: str, room_id: str, aliases: Collection[str]
) -> bool:
"""Checks if the given user is allowed to publish the room
@@ -122,7 +122,7 @@ class _RoomDirectoryRule:
except Exception as e:
raise ConfigError("Failed to parse glob into regex") from e
- def matches(self, user_id: str, room_id: str, aliases: List[str]) -> bool:
+ def matches(self, user_id: str, room_id: str, aliases: Collection[str]) -> bool:
"""Tests if this rule matches the given user_id, room_id and aliases.
Args:
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index e0be9f88cc..4d6d1b8ebd 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -16,18 +16,7 @@
import collections.abc
import logging
import typing
-from typing import (
- Any,
- Collection,
- Dict,
- Iterable,
- List,
- Mapping,
- Optional,
- Set,
- Tuple,
- Union,
-)
+from typing import Any, Dict, Iterable, List, Mapping, Optional, Set, Tuple, Union
from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
@@ -56,7 +45,13 @@ from synapse.api.room_versions import (
RoomVersions,
)
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.types import MutableStateMap, StateMap, UserID, get_domain_from_id
+from synapse.types import (
+ MutableStateMap,
+ StateMap,
+ StrCollection,
+ UserID,
+ get_domain_from_id,
+)
if typing.TYPE_CHECKING:
# conditional imports to avoid import cycle
@@ -69,7 +64,7 @@ logger = logging.getLogger(__name__)
class _EventSourceStore(Protocol):
async def get_events(
self,
- event_ids: Collection[str],
+ event_ids: StrCollection,
redact_behaviour: EventRedactBehaviour,
get_prev_content: bool = False,
allow_rejected: bool = False,
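`StrCollection` is Synapse's shorthand for "any collection of strings". Roughly what the alias amounts to (the canonical definition lives in `synapse.types`; this is a sketch):

```python
from typing import AbstractSet, List, Tuple, Union

# Approximation of the alias in synapse.types.
StrCollection = Union[Tuple[str, ...], List[str], AbstractSet[str]]


def count_events(event_ids: StrCollection) -> int:
    # Only read access is needed, so lists, tuples and sets all qualify
    # without being copied into one canonical container first.
    return len(event_ids)


count_events(["$a:x", "$b:x"])
count_events(("$a:x",))
count_events({"$a:x", "$b:x"})
```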
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 8aca9a3ab9..91118a8d84 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -39,7 +39,7 @@ from unpaddedbase64 import encode_base64
from synapse.api.constants import RelationTypes
from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions
-from synapse.types import JsonDict, RoomStreamToken
+from synapse.types import JsonDict, RoomStreamToken, StrCollection
from synapse.util.caches import intern_dict
from synapse.util.frozenutils import freeze
from synapse.util.stringutils import strtobool
@@ -413,7 +413,7 @@ class EventBase(metaclass=abc.ABCMeta):
"""
return [e for e, _ in self._dict["prev_events"]]
- def auth_event_ids(self) -> Sequence[str]:
+ def auth_event_ids(self) -> StrCollection:
"""Returns the list of auth event IDs. The order matches the order
specified in the event, though there is no meaning to it.
@@ -558,7 +558,7 @@ class FrozenEventV2(EventBase):
"""
return self._dict["prev_events"]
- def auth_event_ids(self) -> Sequence[str]:
+ def auth_event_ids(self) -> StrCollection:
"""Returns the list of auth event IDs. The order matches the order
specified in the event, though there is no meaning to it.
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 94dd1298e1..c82745275f 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Tuple, Union
import attr
from signedjson.types import SigningKey
@@ -103,7 +103,7 @@ class EventBuilder:
async def build(
self,
- prev_event_ids: List[str],
+ prev_event_ids: Collection[str],
auth_event_ids: Optional[List[str]],
depth: Optional[int] = None,
) -> EventBase:
@@ -136,7 +136,7 @@ class EventBuilder:
format_version = self.room_version.event_format
# The types of auth/prev events changes between event versions.
- prev_events: Union[List[str], List[Tuple[str, Dict[str, str]]]]
+ prev_events: Union[Collection[str], List[Tuple[str, Dict[str, str]]]]
auth_events: Union[List[str], List[Tuple[str, Dict[str, str]]]]
if format_version == EventFormatVersions.ROOM_V1_V2:
auth_events = await self._store.add_event_hashes(auth_event_ids)
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 6eaef8b57a..e0d82ad81c 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Optional, Tuple
import attr
@@ -26,8 +27,51 @@ if TYPE_CHECKING:
from synapse.types.state import StateFilter
+class UnpersistedEventContextBase(ABC):
+ """
+ This is a base class for EventContext and UnpersistedEventContext, objects which
+ hold information relevant to storing an associated event. Note that an
+ UnpersistedEventContexts must be converted into an EventContext before it is
+ suitable to send to the db with its associated event.
+
+ Attributes:
+ _storage: storage controllers for interfacing with the database
+ app_service: If the associated event is being sent by a (local) application service, that
+ app service.
+ """
+
+ def __init__(self, storage_controller: "StorageControllers"):
+ self._storage: "StorageControllers" = storage_controller
+ self.app_service: Optional[ApplicationService] = None
+
+ @abstractmethod
+ async def persist(
+ self,
+ event: EventBase,
+ ) -> "EventContext":
+ """
+ A method to convert an UnpersistedEventContext to an EventContext, suitable for
+ sending to the database with the associated event.
+ """
+ pass
+
+ @abstractmethod
+ async def get_prev_state_ids(
+ self, state_filter: Optional["StateFilter"] = None
+ ) -> StateMap[str]:
+ """
+ Gets the room state at the event (ie not including the event if the event is a
+ state event).
+
+ Args:
+ state_filter: specifies the type of state event to fetch from DB, example:
+ EventTypes.JoinRules
+ """
+ pass
+
+
@attr.s(slots=True, auto_attribs=True)
-class EventContext:
+class EventContext(UnpersistedEventContextBase):
"""
Holds information relevant to persisting an event
@@ -77,9 +121,6 @@ class EventContext:
delta_ids: If ``prev_group`` is not None, the state delta between ``prev_group``
and ``state_group``.
- app_service: If this event is being sent by a (local) application service, that
- app service.
-
partial_state: if True, we may be storing this event with a temporary,
incomplete state.
"""
@@ -122,6 +163,9 @@ class EventContext:
"""Return an EventContext instance suitable for persisting an outlier event"""
return EventContext(storage=storage)
+ async def persist(self, event: EventBase) -> "EventContext":
+ return self
+
async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict:
"""Converts self to a type that can be serialized as JSON, and then
deserialized by `deserialize`
@@ -254,6 +298,128 @@ class EventContext:
)
+@attr.s(slots=True, auto_attribs=True)
+class UnpersistedEventContext(UnpersistedEventContextBase):
+ """
+ The event context holds information about the state groups for an event. It is important
+ to remember that an event technically has two state groups: the state group before the
+ event, and the state group after the event. If the event is not a state event, the state
+ group will not change (ie the state group before the event will be the same as the state
+ group after the event), but if it is a state event the state group before the event
+ will differ from the state group after the event.
+
+ This is a version of an EventContext before the new state group (if any) has been
+ computed and stored. It contains information about the state before the event (which
+ also may be the information after the event, if the event is not a state event). The
+ UnpersistedEventContext must be converted into an EventContext by calling the method
+ 'persist' on it before it is suitable to be sent to the DB for processing.
+
+ state_group_after_event:
+ The state group after the event. This will always be None until it is persisted.
+ If the event is not a state event, this will be the same as
+ state_group_before_event.
+
+ state_group_before_event:
+ The ID of the state group representing the state of the room before this event.
+
+ state_delta_due_to_event:
+ If the event is a state event, then this is the delta of the state between
+ `state_group_after_event` and `state_group_before_event`.
+
+ prev_group_for_state_group_before_event:
+ If it is known, ``state_group_before_event``'s previous state group.
+
+ delta_ids_to_state_group_before_event:
+ If ``prev_group_for_state_group_before_event`` is not None, the state delta
+ between ``prev_group_for_state_group_before_event`` and ``state_group_before_event``.
+
+ partial_state:
+ Whether the event has partial state.
+
+ state_map_before_event:
+ A map of the state before the event, i.e. the state at `state_group_before_event`
+ """
+
+ _storage: "StorageControllers"
+ state_group_before_event: Optional[int]
+ state_group_after_event: Optional[int]
+ state_delta_due_to_event: Optional[dict]
+ prev_group_for_state_group_before_event: Optional[int]
+ delta_ids_to_state_group_before_event: Optional[StateMap[str]]
+ partial_state: bool
+ state_map_before_event: Optional[StateMap[str]] = None
+
+ async def get_prev_state_ids(
+ self, state_filter: Optional["StateFilter"] = None
+ ) -> StateMap[str]:
+ """
+ Gets the room state map, excluding this event.
+
+ Args:
+ state_filter: specifies the type of state event to fetch from DB
+
+ Returns:
+ Maps a (type, state_key) to the event ID of the state event matching
+ this tuple.
+ """
+ if self.state_map_before_event:
+ return self.state_map_before_event
+
+ assert self.state_group_before_event is not None
+ return await self._storage.state.get_state_ids_for_group(
+ self.state_group_before_event, state_filter
+ )
+
+ async def persist(self, event: EventBase) -> EventContext:
+ """
+ Creates a full `EventContext` for the event, persisting any referenced state that
+ has not yet been persisted.
+
+ Args:
+ event: event that the EventContext is associated with.
+
+ Returns: An EventContext suitable for sending to the database with the event
+ for persisting
+ """
+ assert self.partial_state is not None
+
+ # If we have a full set of state for before the event but don't have a state
+ # group for that state, we need to get one
+ if self.state_group_before_event is None:
+ assert self.state_map_before_event
+ state_group_before_event = await self._storage.state.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=self.prev_group_for_state_group_before_event,
+ delta_ids=self.delta_ids_to_state_group_before_event,
+ current_state_ids=self.state_map_before_event,
+ )
+ self.state_group_before_event = state_group_before_event
+
+ # if the event isn't a state event the state group doesn't change
+ if not self.state_delta_due_to_event:
+ state_group_after_event = self.state_group_before_event
+
+ # otherwise if it is a state event we need to get a state group for it
+ else:
+ state_group_after_event = await self._storage.state.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=self.state_group_before_event,
+ delta_ids=self.state_delta_due_to_event,
+ current_state_ids=None,
+ )
+
+ return EventContext.with_state(
+ storage=self._storage,
+ state_group=state_group_after_event,
+ state_group_before_event=self.state_group_before_event,
+ state_delta_due_to_event=self.state_delta_due_to_event,
+ partial_state=self.partial_state,
+ prev_group=self.state_group_before_event,
+ delta_ids=self.state_delta_due_to_event,
+ )
+
+
def _encode_state_dict(
state_dict: Optional[StateMap[str]],
) -> Optional[List[Tuple[str, str, str]]]:
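The split between `UnpersistedEventContext` and `EventContext` gives event creation a two-phase shape. A hedged sketch of the intended flow, with `event_creation_handler` and `builder` standing in for the real objects used elsewhere in this diff:

```python
async def create_and_persist(event_creation_handler, builder):
    # Phase 1: build the event plus an UnpersistedEventContextBase. No new
    # state group has been written yet, so modules may still replace the
    # event cheaply.
    event, unpersisted_context = (
        await event_creation_handler.create_new_client_event(builder=builder)
    )

    # Phase 2: once the event is final, write any referenced state groups
    # and obtain the EventContext that accompanies the event to the DB.
    context = await unpersisted_context.persist(event)
    return event, context
```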
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 72ab696898..97c61cc258 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -18,7 +18,7 @@ from twisted.internet.defer import CancelledError
from synapse.api.errors import ModuleFailedException, SynapseError
from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import UnpersistedEventContextBase
from synapse.storage.roommember import ProfileInfo
from synapse.types import Requester, StateMap
from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
@@ -231,7 +231,9 @@ class ThirdPartyEventRules:
self._on_threepid_bind_callbacks.append(on_threepid_bind)
async def check_event_allowed(
- self, event: EventBase, context: EventContext
+ self,
+ event: EventBase,
+ context: UnpersistedEventContextBase,
) -> Tuple[bool, Optional[dict]]:
"""Check if a provided event should be allowed in the given context.
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 8d36172484..6d99845de5 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -23,6 +23,7 @@ from typing import (
Collection,
Dict,
List,
+ Mapping,
Optional,
Tuple,
Union,
@@ -47,6 +48,7 @@ from synapse.api.errors import (
FederationError,
IncompatibleRoomVersionError,
NotFoundError,
+ PartialStateConflictError,
SynapseError,
UnsupportedRoomVersionError,
)
@@ -80,7 +82,6 @@ from synapse.replication.http.federation import (
ReplicationFederationSendEduRestServlet,
ReplicationGetQueryRestServlet,
)
-from synapse.storage.databases.main.events import PartialStateConflictError
from synapse.storage.databases.main.lock import Lock
from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary
from synapse.storage.roommember import MemberSummary
@@ -1512,7 +1513,7 @@ class FederationHandlerRegistry:
def _get_event_ids_for_partial_state_join(
join_event: EventBase,
prev_state_ids: StateMap[str],
- summary: Dict[str, MemberSummary],
+ summary: Mapping[str, MemberSummary],
) -> Collection[str]:
"""Calculate state to be returned in a partial_state send_join
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 67e789eef7..797de46dbc 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -343,10 +343,12 @@ class AccountDataEventSource(EventSource[int, JsonDict]):
}
)
- (
- account_data,
- room_account_data,
- ) = await self.store.get_updated_account_data_for_user(user_id, last_stream_id)
+ account_data = await self.store.get_updated_global_account_data_for_user(
+ user_id, last_stream_id
+ )
+ room_account_data = await self.store.get_updated_room_account_data_for_user(
+ user_id, last_stream_id
+ )
for account_data_type, content in account_data.items():
results.append({"type": account_data_type, "content": content})
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 30f2d46c3c..57a6854b1e 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -1593,9 +1593,8 @@ class AuthHandler:
if medium == "email":
address = canonicalise_email(address)
- identity_handler = self.hs.get_identity_handler()
- result = await identity_handler.try_unbind_threepid(
- user_id, {"medium": medium, "address": address, "id_server": id_server}
+ result = await self.hs.get_identity_handler().try_unbind_threepid(
+ user_id, medium, address, id_server
)
await self.store.user_delete_threepid(user_id, medium, address)
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index d74d135c0c..d24f649382 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -106,12 +106,7 @@ class DeactivateAccountHandler:
for threepid in threepids:
try:
result = await self._identity_handler.try_unbind_threepid(
- user_id,
- {
- "medium": threepid["medium"],
- "address": threepid["address"],
- "id_server": id_server,
- },
+ user_id, threepid["medium"], threepid["address"], id_server
)
identity_server_supports_unbinding &= result
except Exception:
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 2ea52257cb..a5798e9483 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -14,7 +14,7 @@
import logging
import string
-from typing import TYPE_CHECKING, Iterable, List, Optional
+from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence
from typing_extensions import Literal
@@ -485,7 +485,8 @@ class DirectoryHandler:
)
)
if canonical_alias:
- room_aliases.append(canonical_alias)
+ # Ensure we do not mutate room_aliases.
+ room_aliases = list(room_aliases) + [canonical_alias]
if not self.config.roomdirectory.is_publishing_room_allowed(
user_id, room_id, room_aliases
@@ -528,7 +529,7 @@ class DirectoryHandler:
async def get_aliases_for_room(
self, requester: Requester, room_id: str
- ) -> List[str]:
+ ) -> Sequence[str]:
"""
Get a list of the aliases that currently point to this room on this server
"""
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index d2188ca08f..43cbece21b 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -159,19 +159,22 @@ class E2eKeysHandler:
# A map of destination -> user ID -> device IDs.
remote_queries_not_in_cache: Dict[str, Dict[str, Iterable[str]]] = {}
if remote_queries:
- query_list: List[Tuple[str, Optional[str]]] = []
+ user_ids = set()
+ user_and_device_ids: List[Tuple[str, str]] = []
for user_id, device_ids in remote_queries.items():
if device_ids:
- query_list.extend(
+ user_and_device_ids.extend(
(user_id, device_id) for device_id in device_ids
)
else:
- query_list.append((user_id, None))
+ user_ids.add(user_id)
(
user_ids_not_in_cache,
remote_results,
- ) = await self.store.get_user_devices_from_cache(query_list)
+ ) = await self.store.get_user_devices_from_cache(
+ user_ids, user_and_device_ids
+ )
# Check that the homeserver still shares a room with all cached users.
# Note that this check may be slightly racy when a remote user leaves a
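The cache lookup now takes wildcard user queries and explicit (user, device) pairs separately instead of encoding wildcards as `(user_id, None)`. A self-contained sketch of the split performed above:

```python
from typing import Dict, Iterable, List, Set, Tuple

remote_queries: Dict[str, Iterable[str]] = {
    "@alice:remote": ["DEV1", "DEV2"],  # specific devices requested
    "@bob:remote": [],                  # empty list means "all devices"
}

user_ids: Set[str] = set()
user_and_device_ids: List[Tuple[str, str]] = []
for user_id, device_ids in remote_queries.items():
    if device_ids:
        user_and_device_ids.extend((user_id, d) for d in device_ids)
    else:
        user_ids.add(user_id)

assert user_ids == {"@bob:remote"}
assert user_and_device_ids == [("@alice:remote", "DEV1"), ("@alice:remote", "DEV2")]
```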
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index a23a8ce2a1..46dd63c3f0 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -202,7 +202,7 @@ class EventAuthHandler:
state_ids: StateMap[str],
room_version: RoomVersion,
user_id: str,
- prev_member_event: Optional[EventBase],
+ prev_membership: Optional[str],
) -> None:
"""
Check whether a user can join a room without an invite due to restricted join rules.
@@ -214,15 +214,14 @@ class EventAuthHandler:
state_ids: The state of the room as it currently is.
room_version: The room version of the room being joined.
user_id: The user joining the room.
- prev_member_event: The current membership event for this user.
+ prev_membership: The current membership state for this user. `None` if the
+ user has never joined the room (equivalent to "leave").
Raises:
AuthError if the user cannot join the room.
"""
# If the member is invited or currently joined, then nothing to do.
- if prev_member_event and (
- prev_member_event.membership in (Membership.JOIN, Membership.INVITE)
- ):
+ if prev_membership in (Membership.JOIN, Membership.INVITE):
return
# This is not a room with a restricted join rule, so we don't need to do the
@@ -255,13 +254,14 @@ class EventAuthHandler:
)
async def has_restricted_join_rules(
- self, state_ids: StateMap[str], room_version: RoomVersion
+ self, partial_state_ids: StateMap[str], room_version: RoomVersion
) -> bool:
"""
Return if the room has the proper join rules set for access via rooms.
Args:
- state_ids: The state of the room as it currently is.
+ partial_state_ids: The state of the room as it currently is. May be full or partial
+ state.
room_version: The room version of the room to query.
Returns:
@@ -272,7 +272,7 @@ class EventAuthHandler:
return False
# If there's no join rule, then it defaults to invite (so this doesn't apply).
- join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None)
+ join_rules_event_id = partial_state_ids.get((EventTypes.JoinRules, ""), None)
if not join_rules_event_id:
return False
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 7f64130e0a..08727e4857 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -49,6 +49,7 @@ from synapse.api.errors import (
FederationPullAttemptBackoffError,
HttpResponseException,
NotFoundError,
+ PartialStateConflictError,
RequestSendFailed,
SynapseError,
)
@@ -56,7 +57,7 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.crypto.event_signing import compute_event_signature
from synapse.event_auth import validate_event_for_room_version
from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
from synapse.events.validator import EventValidator
from synapse.federation.federation_client import InvalidResponseError
from synapse.http.servlet import assert_params_in_dict
@@ -68,7 +69,6 @@ from synapse.replication.http.federation import (
ReplicationCleanRoomRestServlet,
ReplicationStoreRoomOnOutlierMembershipRestServlet,
)
-from synapse.storage.databases.main.events import PartialStateConflictError
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.types import JsonDict, StrCollection, get_domain_from_id
from synapse.types.state import StateFilter
@@ -990,7 +990,10 @@ class FederationHandler:
)
try:
- event, context = await self.event_creation_handler.create_new_client_event(
+ (
+ event,
+ unpersisted_context,
+ ) = await self.event_creation_handler.create_new_client_event(
builder=builder
)
except SynapseError as e:
@@ -998,7 +1001,9 @@ class FederationHandler:
raise
# Ensure the user can even join the room.
- await self._federation_event_handler.check_join_restrictions(context, event)
+ await self._federation_event_handler.check_join_restrictions(
+ unpersisted_context, event
+ )
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_join_request`
@@ -1178,7 +1183,7 @@ class FederationHandler:
},
)
- event, context = await self.event_creation_handler.create_new_client_event(
+ event, _ = await self.event_creation_handler.create_new_client_event(
builder=builder
)
@@ -1228,12 +1233,13 @@ class FederationHandler:
},
)
- event, context = await self.event_creation_handler.create_new_client_event(
- builder=builder
- )
+ (
+ event,
+ unpersisted_context,
+ ) = await self.event_creation_handler.create_new_client_event(builder=builder)
event_allowed, _ = await self.third_party_event_rules.check_event_allowed(
- event, context
+ event, unpersisted_context
)
if not event_allowed:
logger.warning("Creation of knock %s forbidden by third-party rules", event)
@@ -1406,15 +1412,20 @@ class FederationHandler:
try:
(
event,
- context,
+ unpersisted_context,
) = await self.event_creation_handler.create_new_client_event(
builder=builder
)
- event, context = await self.add_display_name_to_third_party_invite(
- room_version_obj, event_dict, event, context
+ (
+ event,
+ unpersisted_context,
+ ) = await self.add_display_name_to_third_party_invite(
+ room_version_obj, event_dict, event, unpersisted_context
)
+ context = await unpersisted_context.persist(event)
+
EventValidator().validate_new(event, self.config)
# We need to tell the transaction queue to send this out, even
@@ -1483,14 +1494,19 @@ class FederationHandler:
try:
(
event,
- context,
+ unpersisted_context,
) = await self.event_creation_handler.create_new_client_event(
builder=builder
)
- event, context = await self.add_display_name_to_third_party_invite(
- room_version_obj, event_dict, event, context
+ (
+ event,
+ unpersisted_context,
+ ) = await self.add_display_name_to_third_party_invite(
+ room_version_obj, event_dict, event, unpersisted_context
)
+ context = await unpersisted_context.persist(event)
+
try:
validate_event_for_room_version(event)
await self._event_auth_handler.check_auth_rules_from_context(event)
@@ -1522,8 +1538,8 @@ class FederationHandler:
room_version_obj: RoomVersion,
event_dict: JsonDict,
event: EventBase,
- context: EventContext,
- ) -> Tuple[EventBase, EventContext]:
+ context: UnpersistedEventContextBase,
+ ) -> Tuple[EventBase, UnpersistedEventContextBase]:
key = (
EventTypes.ThirdPartyInvite,
event.content["third_party_invite"]["signed"]["token"],
@@ -1557,11 +1573,14 @@ class FederationHandler:
room_version_obj, event_dict
)
EventValidator().validate_builder(builder)
- event, context = await self.event_creation_handler.create_new_client_event(
- builder=builder
- )
+
+ (
+ event,
+ unpersisted_context,
+ ) = await self.event_creation_handler.create_new_client_event(builder=builder)
+
EventValidator().validate_new(event, self.config)
- return event, context
+ return event, unpersisted_context
async def _check_signature(self, event: EventBase, context: EventContext) -> None:
"""
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index e037acbca2..b7136f8d1c 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -47,6 +47,7 @@ from synapse.api.errors import (
FederationError,
FederationPullAttemptBackoffError,
HttpResponseException,
+ PartialStateConflictError,
RequestSendFailed,
SynapseError,
)
@@ -58,7 +59,7 @@ from synapse.event_auth import (
validate_event_for_room_version,
)
from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
from synapse.federation.federation_client import InvalidResponseError, PulledPduInfo
from synapse.logging.context import nested_logging_context
from synapse.logging.opentracing import (
@@ -74,7 +75,6 @@ from synapse.replication.http.federation import (
ReplicationFederationSendEventsRestServlet,
)
from synapse.state import StateResolutionStore
-from synapse.storage.databases.main.events import PartialStateConflictError
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.types import (
PersistedEventPosition,
@@ -426,7 +426,9 @@ class FederationEventHandler:
return event, context
async def check_join_restrictions(
- self, context: EventContext, event: EventBase
+ self,
+ context: UnpersistedEventContextBase,
+ event: EventBase,
) -> None:
"""Check that restrictions in restricted join rules are matched
@@ -439,16 +441,17 @@ class FederationEventHandler:
# Check if the user is already in the room or invited to the room.
user_id = event.state_key
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
- prev_member_event = None
+ prev_membership = None
if prev_member_event_id:
prev_member_event = await self._store.get_event(prev_member_event_id)
+ prev_membership = prev_member_event.membership
# Check if the member should be allowed access via membership in a space.
await self._event_auth_handler.check_restricted_join_rules(
prev_state_ids,
event.room_version,
user_id,
- prev_member_event,
+ prev_membership,
)
@trace
@@ -524,11 +527,57 @@ class FederationEventHandler:
"Peristing join-via-remote %s (partial_state: %s)", event, partial_state
)
with nested_logging_context(suffix=event.event_id):
+ if partial_state:
+ # When handling a second partial state join into a partial state room,
+ # the returned state will exclude the membership from the first join. To
+ # preserve prior memberships, we try to compute the partial state before
+ # the event ourselves if we know about any of the prev events.
+ #
+ # When we don't know about any of the prev events, it's fine to just use
+ # the returned state, since the new join will create a new forward
+ # extremity, and leave the forward extremity containing our prior
+ # memberships alone.
+ prev_event_ids = set(event.prev_event_ids())
+ seen_event_ids = await self._store.have_events_in_timeline(
+ prev_event_ids
+ )
+ missing_event_ids = prev_event_ids - seen_event_ids
+
+ state_maps_to_resolve: List[StateMap[str]] = []
+
+ # Fetch the state after the prev events that we know about.
+ state_maps_to_resolve.extend(
+ (
+ await self._state_storage_controller.get_state_groups_ids(
+ room_id, seen_event_ids, await_full_state=False
+ )
+ ).values()
+ )
+
+ # When there are prev events we do not have the state for, we state
+ # resolve with the state returned by the remote homeserver.
+ if missing_event_ids or len(state_maps_to_resolve) == 0:
+ state_maps_to_resolve.append(
+ {(e.type, e.state_key): e.event_id for e in state}
+ )
+
+ state_ids_before_event = (
+ await self._state_resolution_handler.resolve_events_with_store(
+ event.room_id,
+ room_version.identifier,
+ state_maps_to_resolve,
+ event_map=None,
+ state_res_store=StateResolutionStore(self._store),
+ )
+ )
+ else:
+ state_ids_before_event = {
+ (e.type, e.state_key): e.event_id for e in state
+ }
+
context = await self._state_handler.compute_event_context(
event,
- state_ids_before_event={
- (e.type, e.state_key): e.event_id for e in state
- },
+ state_ids_before_event=state_ids_before_event,
partial_state=partial_state,
)
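The intuition behind the new branch, as a toy example: resolving the remote's returned state against state we already hold preserves the first join's membership. (Real resolution uses the room-version algorithm via `resolve_events_with_store`, not a plain merge; this sketch only shows the non-conflicting case.)

```python
# State after the first partial-state join, which we already know about.
state_we_hold = {("m.room.member", "@first:local"): "$join1"}
# State returned by the remote for the second join; it omits @first.
state_from_remote = {("m.room.member", "@second:local"): "$join2"}

# With no conflicting keys, resolution reduces to a union, so the first
# user's membership survives into state_ids_before_event.
resolved = {**state_we_hold, **state_from_remote}
assert ("m.room.member", "@first:local") in resolved
assert ("m.room.member", "@second:local") in resolved
```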
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 848e46eb9b..bf0f7acf80 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -219,28 +219,31 @@ class IdentityHandler:
data = json_decoder.decode(e.msg) # XXX WAT?
return data
- async def try_unbind_threepid(self, mxid: str, threepid: dict) -> bool:
- """Attempt to remove a 3PID from an identity server, or if one is not provided, all
- identity servers we're aware the binding is present on
+ async def try_unbind_threepid(
+ self, mxid: str, medium: str, address: str, id_server: Optional[str]
+ ) -> bool:
+ """Attempt to remove a 3PID from one or more identity servers.
Args:
mxid: Matrix user ID of binding to be removed
- threepid: Dict with medium & address of binding to be
- removed, and an optional id_server.
+ medium: The medium of the third-party ID.
+ address: The address of the third-party ID.
+ id_server: An identity server to attempt to unbind from. If None,
+ attempt to remove the association from all identity servers
+ known to potentially have it.
Raises:
- SynapseError: If we failed to contact the identity server
+ SynapseError: If we failed to contact one or more identity servers.
Returns:
- True on success, otherwise False if the identity
- server doesn't support unbinding (or no identity server found to
- contact).
+ True on success, otherwise False if the identity server doesn't
+ support unbinding (or no identity server to contact was found).
"""
- if threepid.get("id_server"):
- id_servers = [threepid["id_server"]]
+ if id_server:
+ id_servers = [id_server]
else:
id_servers = await self.store.get_id_servers_user_bound(
- user_id=mxid, medium=threepid["medium"], address=threepid["address"]
+ mxid, medium, address
)
# We don't know where to unbind, so we don't have a choice but to return
@@ -249,20 +252,21 @@ class IdentityHandler:
changed = True
for id_server in id_servers:
- changed &= await self.try_unbind_threepid_with_id_server(
- mxid, threepid, id_server
+ changed &= await self._try_unbind_threepid_with_id_server(
+ mxid, medium, address, id_server
)
return changed
- async def try_unbind_threepid_with_id_server(
- self, mxid: str, threepid: dict, id_server: str
+ async def _try_unbind_threepid_with_id_server(
+ self, mxid: str, medium: str, address: str, id_server: str
) -> bool:
"""Removes a binding from an identity server
Args:
mxid: Matrix user ID of binding to be removed
- threepid: Dict with medium & address of binding to be removed
+ medium: The medium of the third-party ID
+ address: The address of the third-party ID
id_server: Identity server to unbind from
Raises:
@@ -286,7 +290,7 @@ class IdentityHandler:
content = {
"mxid": mxid,
- "threepid": {"medium": threepid["medium"], "address": threepid["address"]},
+ "threepid": {"medium": medium, "address": address},
}
# we abuse the federation http client to sign the request, but we have to send it
@@ -319,12 +323,7 @@ class IdentityHandler:
except RequestTimedOutError:
raise SynapseError(500, "Timed out contacting identity server")
- await self.store.remove_user_bound_threepid(
- user_id=mxid,
- medium=threepid["medium"],
- address=threepid["address"],
- id_server=id_server,
- )
+ await self.store.remove_user_bound_threepid(mxid, medium, address, id_server)
return changed
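A hedged usage sketch of the flattened signature; `identity_handler` stands in for the handler patched above, and the address is illustrative:

```python
async def unbind_email(identity_handler, user_id: str) -> bool:
    # Passing id_server=None asks the handler to contact every identity
    # server the binding is known to exist on.
    return await identity_handler.try_unbind_threepid(
        user_id, "email", "alice@example.org", None
    )
```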
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 191529bd8e..1a29abde98 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -154,9 +154,8 @@ class InitialSyncHandler:
tags_by_room = await self.store.get_tags_for_user(user_id)
- account_data, account_data_by_room = await self.store.get_account_data_for_user(
- user_id
- )
+ account_data = await self.store.get_global_account_data_for_user(user_id)
+ account_data_by_room = await self.store.get_room_account_data_for_user(user_id)
public_room_ids = await self.store.get_public_room_ids()
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index e688e00575..8f5b658d9d 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -38,6 +38,7 @@ from synapse.api.errors import (
Codes,
ConsentNotGivenError,
NotFoundError,
+ PartialStateConflictError,
ShadowBanError,
SynapseError,
UnstableSpecAuthError,
@@ -48,7 +49,7 @@ from synapse.api.urls import ConsentURIBuilder
from synapse.event_auth import validate_event_for_room_version
from synapse.events import EventBase, relation_from_event
from synapse.events.builder import EventBuilder
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
from synapse.events.utils import maybe_upsert_event_field
from synapse.events.validator import EventValidator
from synapse.handlers.directory import DirectoryHandler
@@ -57,7 +58,6 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.replication.http.send_events import ReplicationSendEventsRestServlet
-from synapse.storage.databases.main.events import PartialStateConflictError
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.types import (
MutableStateMap,
@@ -499,9 +499,9 @@ class EventCreationHandler:
self.request_ratelimiter = hs.get_request_ratelimiter()
- # We arbitrarily limit concurrent event creation for a room to 5.
- # This is to stop us from diverging history *too* much.
- self.limiter = Linearizer(max_count=5, name="room_event_creation_limit")
+ # We limit concurrent event creation for a room to 1. This prevents state resolution
+ # from occurring when sending bursts of events to a local room.
+ self.limiter = Linearizer(max_count=1, name="room_event_creation_limit")
self._bulk_push_rule_evaluator = hs.get_bulk_push_rule_evaluator()
@@ -708,7 +708,7 @@ class EventCreationHandler:
builder.internal_metadata.historical = historical
- event, context = await self.create_new_client_event(
+ event, unpersisted_context = await self.create_new_client_event(
builder=builder,
requester=requester,
allow_no_prev_events=allow_no_prev_events,
@@ -721,6 +721,8 @@ class EventCreationHandler:
current_state_group=current_state_group,
)
+ context = await unpersisted_context.persist(event)
+
# In an ideal world we wouldn't need the second part of this condition. However,
# this behaviour isn't spec'd yet, meaning we should be able to deactivate this
# behaviour. Another reason is that this code is also evaluated each time a new
@@ -1083,13 +1085,14 @@ class EventCreationHandler:
state_map: Optional[StateMap[str]] = None,
for_batch: bool = False,
current_state_group: Optional[int] = None,
- ) -> Tuple[EventBase, EventContext]:
+ ) -> Tuple[EventBase, UnpersistedEventContextBase]:
"""Create a new event for a local client. If bool for_batch is true, will
create an event using the prev_event_ids, and will create an event context for
the event using the parameters state_map and current_state_group, thus these parameters
must be provided in this case if for_batch is True. The subsequently created event
and context are suitable for being batched up and bulk persisted to the database
- with other similarly created events.
+ with other similarly created events. Note that this returns an UnpersistedEventContext,
+ which must be converted to an EventContext before it can be sent to the DB.
Args:
builder:
@@ -1131,7 +1134,7 @@ class EventCreationHandler:
batch persisting
Returns:
- Tuple of created event, context
+ Tuple of created event, UnpersistedEventContext
"""
# Strip down the state_event_ids to only what we need to auth the event.
# For example, we don't need extra m.room.member that don't match event.sender
@@ -1192,9 +1195,16 @@ class EventCreationHandler:
event = await builder.build(
prev_event_ids=prev_event_ids, auth_event_ids=auth_ids, depth=depth
)
- context = await self.state.compute_event_context_for_batched(
- event, state_map, current_state_group
+
+ context: UnpersistedEventContextBase = (
+ await self.state.calculate_context_info(
+ event,
+ state_ids_before_event=state_map,
+ partial_state=False,
+ state_group_before_event=current_state_group,
+ )
)
+
else:
event = await builder.build(
prev_event_ids=prev_event_ids,
@@ -1244,16 +1254,17 @@ class EventCreationHandler:
state_map_for_event[(data.event_type, data.state_key)] = state_id
- context = await self.state.compute_event_context(
+ # TODO(faster_joins): check how MSC2716 works and whether we can have
+ # partial state here
+ # https://github.com/matrix-org/synapse/issues/13003
+ context = await self.state.calculate_context_info(
event,
state_ids_before_event=state_map_for_event,
- # TODO(faster_joins): check how MSC2716 works and whether we can have
- # partial state here
- # https://github.com/matrix-org/synapse/issues/13003
partial_state=False,
)
+
else:
- context = await self.state.compute_event_context(event)
+ context = await self.state.calculate_context_info(event)
if requester:
context.app_service = requester.app_service
@@ -2082,9 +2093,9 @@ class EventCreationHandler:
async def _rebuild_event_after_third_party_rules(
self, third_party_result: dict, original_event: EventBase
- ) -> Tuple[EventBase, EventContext]:
+ ) -> Tuple[EventBase, UnpersistedEventContextBase]:
# the third_party_event_rules want to replace the event.
- # we do some basic checks, and then return the replacement event and context.
+ # we do some basic checks, and then return the replacement event and unpersisted context.
# Construct a new EventBuilder and validate it, which helps with the
# rest of these checks.
@@ -2138,5 +2149,6 @@ class EventCreationHandler:
# we rebuild the event context, to be on the safe side. If nothing else,
# delta_ids might need an update.
- context = await self.state.compute_event_context(event)
+ context = await self.state.calculate_context_info(event)
+
return event, context
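With `max_count=1` the Linearizer fully serialises event creation per room. A minimal per-key limiter with the same observable behaviour (a sketch, not Synapse's `Linearizer`):

```python
import asyncio
from collections import defaultdict


class PerKeyLimiter:
    def __init__(self, max_count: int) -> None:
        self._semaphores = defaultdict(lambda: asyncio.Semaphore(max_count))

    def queue(self, key: str) -> asyncio.Semaphore:
        # Usage: `async with limiter.queue(room_id): ...` -- with
        # max_count=1 this serialises all event creation for a room,
        # avoiding the divergent histories that state resolution would
        # otherwise have to reconcile.
        return self._semaphores[key]


limiter = PerKeyLimiter(max_count=1)
```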
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 04c61ae3dd..2bacdebfb5 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Tuple
from synapse.api.constants import EduTypes, ReceiptTypes
from synapse.appservice import ApplicationService
@@ -189,7 +189,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
@staticmethod
def filter_out_private_receipts(
- rooms: List[JsonDict], user_id: str
+ rooms: Sequence[JsonDict], user_id: str
) -> List[JsonDict]:
"""
Filters a list of serialized receipts (as returned by /sync and /initialSync)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 7ba7c4ff07..837dabb3b7 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -43,6 +43,7 @@ from synapse.api.errors import (
Codes,
LimitExceededError,
NotFoundError,
+ PartialStateConflictError,
StoreError,
SynapseError,
)
@@ -54,7 +55,6 @@ from synapse.events.utils import copy_and_fixup_power_levels_contents
from synapse.handlers.relations import BundledAggregations
from synapse.module_api import NOT_SPAM
from synapse.rest.admin._base import assert_user_is_admin
-from synapse.storage.databases.main.events import PartialStateConflictError
from synapse.streams import EventSource
from synapse.types import (
JsonDict,
@@ -1076,7 +1076,7 @@ class RoomCreationHandler:
state_map: MutableStateMap[str] = {}
# current_state_group of last event created. Used for computing event context of
# events to be batched
- current_state_group = None
+ current_state_group: Optional[int] = None
def create_event_dict(etype: str, content: JsonDict, **kwargs: Any) -> JsonDict:
e = {"type": etype, "content": content}
@@ -1928,6 +1928,6 @@ class RoomShutdownHandler:
return {
"kicked_users": kicked_users,
"failed_to_kick_users": failed_to_kick_users,
- "local_aliases": aliases_for_room,
+ "local_aliases": list(aliases_for_room),
"new_room_id": new_room_id,
}
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index d236cc09b5..a965c7ec76 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -26,7 +26,13 @@ from synapse.api.constants import (
GuestAccess,
Membership,
)
-from synapse.api.errors import AuthError, Codes, ShadowBanError, SynapseError
+from synapse.api.errors import (
+ AuthError,
+ Codes,
+ PartialStateConflictError,
+ ShadowBanError,
+ SynapseError,
+)
from synapse.api.ratelimiting import Ratelimiter
from synapse.event_auth import get_named_level, get_power_level_event
from synapse.events import EventBase
@@ -34,7 +40,6 @@ from synapse.events.snapshot import EventContext
from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
from synapse.logging import opentracing
from synapse.module_api import NOT_SPAM
-from synapse.storage.databases.main.events import PartialStateConflictError
from synapse.types import (
JsonDict,
Requester,
@@ -56,6 +61,13 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+class NoKnownServersError(SynapseError):
+ """No server already resident to the room was provided to the join/knock operation."""
+
+ def __init__(self, msg: str = "No known servers"):
+ super().__init__(404, msg)
+
+
class RoomMemberHandler(metaclass=abc.ABCMeta):
# TODO(paul): This handler currently contains a messy conflation of
# low-level API that works on UserID objects and so on, and REST-level
@@ -185,6 +197,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
room_id: Room that we are trying to join
user: User who is trying to join
content: A dict that should be used as the content of the join event.
+
+ Raises:
+ NoKnownServersError: if remote_room_hosts does not contain a server joined to
+ the room.
"""
raise NotImplementedError()
@@ -484,7 +500,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
user_id: The user's ID.
"""
# Retrieve user account data for predecessor room
- user_account_data, _ = await self.store.get_account_data_for_user(user_id)
+ user_account_data = await self.store.get_global_account_data_for_user(user_id)
# Copy direct message state if applicable
direct_rooms = user_account_data.get(AccountDataTypes.DIRECT, {})
@@ -823,14 +839,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
latest_event_ids = await self.store.get_prev_events_for_room(room_id)
- state_before_join = await self.state_handler.compute_state_after_events(
- room_id, latest_event_ids
+ is_partial_state_room = await self.store.is_partial_state_room(room_id)
+ partial_state_before_join = await self.state_handler.compute_state_after_events(
+ room_id, latest_event_ids, await_full_state=False
)
+ # `is_partial_state_room` also indicates whether `partial_state_before_join` is
+ # partial.
# TODO: Refactor into dictionary of explicitly allowed transitions
# between old and new state, with specific error messages for some
# transitions and generic otherwise
- old_state_id = state_before_join.get((EventTypes.Member, target.to_string()))
+ old_state_id = partial_state_before_join.get(
+ (EventTypes.Member, target.to_string())
+ )
if old_state_id:
old_state = await self.store.get_event(old_state_id, allow_none=True)
old_membership = old_state.content.get("membership") if old_state else None
@@ -881,11 +902,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if action == "kick":
raise AuthError(403, "The target user is not in the room")
- is_host_in_room = await self._is_host_in_room(state_before_join)
+ is_host_in_room = await self._is_host_in_room(partial_state_before_join)
if effective_membership_state == Membership.JOIN:
if requester.is_guest:
- guest_can_join = await self._can_guest_join(state_before_join)
+ guest_can_join = await self._can_guest_join(partial_state_before_join)
if not guest_can_join:
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
@@ -927,8 +948,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
room_id,
remote_room_hosts,
content,
+ is_partial_state_room,
is_host_in_room,
- state_before_join,
+ partial_state_before_join,
)
if remote_join:
if ratelimit:
@@ -1073,8 +1095,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
room_id: str,
remote_room_hosts: List[str],
content: JsonDict,
+ is_partial_state_room: bool,
is_host_in_room: bool,
- state_before_join: StateMap[str],
+ partial_state_before_join: StateMap[str],
) -> Tuple[bool, List[str]]:
"""
Check whether the server should do a remote join (as opposed to a local
@@ -1093,9 +1116,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
remote_room_hosts: A list of remote room hosts.
content: The content to use as the event body of the join. This may
be modified.
- is_host_in_room: True if the host is in the room.
- state_before_join: The state before the join event (i.e. the resolution of
- the states after its parent events).
+ is_partial_state_room: `True` if the server currently doesn't hold the full
+ state of the room.
+ is_host_in_room: `True` if the host is in the room.
+ partial_state_before_join: The state before the join event (i.e. the
+ resolution of the states after its parent events). May be full or
+ partial state, depending on `is_partial_state_room`.
Returns:
A tuple of:
@@ -1109,6 +1135,23 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if not is_host_in_room:
return True, remote_room_hosts
+ prev_member_event_id = partial_state_before_join.get(
+ (EventTypes.Member, user_id), None
+ )
+ previous_membership = None
+ if prev_member_event_id:
+ prev_member_event = await self.store.get_event(prev_member_event_id)
+ previous_membership = prev_member_event.membership
+
+ # If we are not fully joined yet, and the target is not already in the room,
+ # let's do a remote join so another server with the full state can validate
+ # that the user has not been banned for example.
+ # We could just accept the join and wait for state res to resolve that later on
+ # but we would then leak room history to this person until then, which is pretty
+ # bad.
+ if is_partial_state_room and previous_membership != Membership.JOIN:
+ return True, remote_room_hosts
+
# If the host is in the room, but not one of the authorised hosts
# for restricted join rules, a remote join must be used.
room_version = await self.store.get_room_version(room_id)
@@ -1116,21 +1159,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# If restricted join rules are not being used, a local join can always
# be used.
if not await self.event_auth_handler.has_restricted_join_rules(
- state_before_join, room_version
+ partial_state_before_join, room_version
):
return False, []
# If the user is invited to the room or already joined, the join
# event can always be issued locally.
- prev_member_event_id = state_before_join.get((EventTypes.Member, user_id), None)
- prev_member_event = None
- if prev_member_event_id:
- prev_member_event = await self.store.get_event(prev_member_event_id)
- if prev_member_event.membership in (
- Membership.JOIN,
- Membership.INVITE,
- ):
- return False, []
+ if previous_membership in (Membership.JOIN, Membership.INVITE):
+ return False, []
+
+ # All the partial state cases are covered above. We have been given the full
+ # state of the room.
+ assert not is_partial_state_room
+ state_before_join = partial_state_before_join
# If the local host has a user who can issue invites, then a local
# join can be done.
@@ -1154,7 +1195,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# Ensure the member should be allowed access via membership in a room.
await self.event_auth_handler.check_restricted_join_rules(
- state_before_join, room_version, user_id, prev_member_event
+ state_before_join, room_version, user_id, previous_membership
)
# If this is going to be a local join, additional information must
@@ -1304,11 +1345,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if prev_member_event.membership == Membership.JOIN:
await self._user_left_room(target_user, room_id)
- async def _can_guest_join(self, current_state_ids: StateMap[str]) -> bool:
+ async def _can_guest_join(self, partial_current_state_ids: StateMap[str]) -> bool:
"""
Returns whether a guest can join a room based on its current state.
+
+ Args:
+ partial_current_state_ids: The current state of the room. May be full or
+ partial state.
"""
- guest_access_id = current_state_ids.get((EventTypes.GuestAccess, ""), None)
+ guest_access_id = partial_current_state_ids.get(
+ (EventTypes.GuestAccess, ""), None
+ )
if not guest_access_id:
return False
@@ -1634,19 +1681,25 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
)
return event, stream_id
- async def _is_host_in_room(self, current_state_ids: StateMap[str]) -> bool:
+ async def _is_host_in_room(self, partial_current_state_ids: StateMap[str]) -> bool:
+ """Returns whether the homeserver is in the room based on its current state.
+
+ Args:
+ partial_current_state_ids: The current state of the room. May be full or
+ partial state.
+ """
# Have we just created the room, and is this about to be the very
# first member event?
- create_event_id = current_state_ids.get(("m.room.create", ""))
- if len(current_state_ids) == 1 and create_event_id:
+ create_event_id = partial_current_state_ids.get(("m.room.create", ""))
+ if len(partial_current_state_ids) == 1 and create_event_id:
# We can only get here if we're in the process of creating the room
return True
- for etype, state_key in current_state_ids:
+ for etype, state_key in partial_current_state_ids:
if etype != EventTypes.Member or not self.hs.is_mine_id(state_key):
continue
- event_id = current_state_ids[(etype, state_key)]
+ event_id = partial_current_state_ids[(etype, state_key)]
event = await self.store.get_event(event_id, allow_none=True)
if not event:
continue
@@ -1715,8 +1768,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
]
if len(remote_room_hosts) == 0:
- raise SynapseError(
- 404,
+ raise NoKnownServersError(
"Can't join remote room because no servers "
"that are in the room have been provided.",
)
@@ -1947,7 +1999,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
]
if len(remote_room_hosts) == 0:
- raise SynapseError(404, "No known servers")
+ raise NoKnownServersError()
return await self.federation_handler.do_knock(
remote_room_hosts, room_id, user.to_string(), content=content
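With the membership lookup hoisted above the restricted-join logic, the join decision now runs in a fixed order: no local copy of the room, then partial state, then join rules. A minimal sketch of that decision order, with simplified stand-in types (not the real handler, which also consults event auth and the store):

    from enum import Enum
    from typing import List, Optional, Tuple

    class Membership(str, Enum):
        JOIN = "join"
        INVITE = "invite"

    def decide_join_route(
        is_host_in_room: bool,
        is_partial_state_room: bool,
        previous_membership: Optional[Membership],
        has_restricted_join_rules: bool,
        remote_room_hosts: List[str],
    ) -> Tuple[bool, List[str]]:
        """Returns (remote_join_required, servers_to_try)."""
        if not is_host_in_room:
            # We have no copy of the room at all.
            return True, remote_room_hosts
        if is_partial_state_room and previous_membership != Membership.JOIN:
            # Only a server with the full state can check e.g. bans, and a
            # premature local join would leak room history.
            return True, remote_room_hosts
        if not has_restricted_join_rules:
            return False, []
        if previous_membership in (Membership.JOIN, Membership.INVITE):
            return False, []
        # Full state is available from here on; the remaining restricted-join
        # checks (authorised hosts, local inviters) decide the rest.
        return False, []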
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index 221552a2a6..ba261702d4 100644
--- a/synapse/handlers/room_member_worker.py
+++ b/synapse/handlers/room_member_worker.py
@@ -15,8 +15,7 @@
import logging
from typing import TYPE_CHECKING, List, Optional, Tuple
-from synapse.api.errors import SynapseError
-from synapse.handlers.room_member import RoomMemberHandler
+from synapse.handlers.room_member import NoKnownServersError, RoomMemberHandler
from synapse.replication.http.membership import (
ReplicationRemoteJoinRestServlet as ReplRemoteJoin,
ReplicationRemoteKnockRestServlet as ReplRemoteKnock,
@@ -52,7 +51,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
) -> Tuple[str, int]:
"""Implements RoomMemberHandler._remote_join"""
if len(remote_room_hosts) == 0:
- raise SynapseError(404, "No known servers")
+ raise NoKnownServersError()
ret = await self._remote_join_client(
requester=requester,
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index 4472019fbc..807245160d 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -521,8 +521,8 @@ class RoomSummaryHandler:
It should return true if:
- * The requester is joined or can join the room (per MSC3173).
- * The origin server has any user that is joined or can join the room.
+ * The requesting user is joined or can join the room (per MSC3173); or
+ * The origin server has any user that is joined or can join the room; or
* The history visibility is set to world readable.
Args:
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 3566537894..4e4595312c 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -269,6 +269,8 @@ class SyncHandler:
self._state_storage_controller = self._storage_controllers.state
self._device_handler = hs.get_device_handler()
+ self.should_calculate_push_rules = hs.config.push.enable_push
+
# TODO: flush cache entries on subsequent sync request.
# Once we get the next /sync request (ie, one with the same access token
# that sets 'since' to 'next_batch'), we know that device won't need a
@@ -1288,6 +1290,12 @@ class SyncHandler:
async def unread_notifs_for_room_id(
self, room_id: str, sync_config: SyncConfig
) -> RoomNotifCounts:
+ if not self.should_calculate_push_rules:
+ # If push rules have been universally disabled then we know we won't
+ # have any unread counts in the DB, so we may as well skip asking
+ # the DB.
+ return RoomNotifCounts.empty()
+
with Measure(self.clock, "unread_notifs_for_room_id"):
return await self.store.get_unread_event_push_actions_by_room_for_user(
@@ -1391,6 +1399,11 @@ class SyncHandler:
for room_id, is_partial_state in results.items()
if is_partial_state
)
+            # Drop membership change events for rooms that are still
+            # partial-stated; those rooms are omitted from the sync until
+            # their full state is available.
+            membership_change_events = [
+ event
+ for event in membership_change_events
+ if not results.get(event.room_id, False)
+ ]
# Incremental eager syncs should additionally include rooms that
# - we are joined to
@@ -1444,9 +1457,9 @@ class SyncHandler:
logger.debug("Fetching account data")
- account_data_by_room = await self._generate_sync_entry_for_account_data(
- sync_result_builder
- )
+ # Global account data is included if it is not filtered out.
+ if not sync_config.filter_collection.blocks_all_global_account_data():
+ await self._generate_sync_entry_for_account_data(sync_result_builder)
# Presence data is included if the server has it enabled and not filtered out.
include_presence_data = bool(
@@ -1472,9 +1485,7 @@ class SyncHandler:
(
newly_joined_rooms,
newly_left_rooms,
- ) = await self._generate_sync_entry_for_rooms(
- sync_result_builder, account_data_by_room
- )
+ ) = await self._generate_sync_entry_for_rooms(sync_result_builder)
# Work out which users have joined or left rooms we're in. We use this
# to build the presence and device_list parts of the sync response in
@@ -1521,7 +1532,7 @@ class SyncHandler:
one_time_keys_count = await self.store.count_e2e_one_time_keys(
user_id, device_id
)
- unused_fallback_key_types = (
+ unused_fallback_key_types = list(
await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
)
@@ -1717,35 +1728,29 @@ class SyncHandler:
async def _generate_sync_entry_for_account_data(
self, sync_result_builder: "SyncResultBuilder"
- ) -> Dict[str, Dict[str, JsonDict]]:
- """Generates the account data portion of the sync response.
+ ) -> None:
+ """Generates the global account data portion of the sync response.
Account data (called "Client Config" in the spec) can be set either globally
or for a specific room. Account data consists of a list of events which
accumulate state, much like a room.
- This function retrieves global and per-room account data. The former is written
- to the given `sync_result_builder`. The latter is returned directly, to be
- later written to the `sync_result_builder` on a room-by-room basis.
+ This function retrieves global account data and writes it to the given
+ `sync_result_builder`. See `_generate_sync_entry_for_rooms` for handling
+ of per-room account data.
Args:
sync_result_builder
-
- Returns:
- A dictionary whose keys (room ids) map to the per room account data for that
- room.
"""
sync_config = sync_result_builder.sync_config
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
if since_token and not sync_result_builder.full_state:
- # TODO Do not fetch room account data if it will be unused.
- (
- global_account_data,
- account_data_by_room,
- ) = await self.store.get_updated_account_data_for_user(
- user_id, since_token.account_data_key
+ global_account_data = (
+ await self.store.get_updated_global_account_data_for_user(
+ user_id, since_token.account_data_key
+ )
)
push_rules_changed = await self.store.have_push_rules_changed_for_user(
@@ -1753,31 +1758,31 @@ class SyncHandler:
)
if push_rules_changed:
+ global_account_data = dict(global_account_data)
global_account_data["m.push_rules"] = await self.push_rules_for_user(
sync_config.user
)
else:
- # TODO Do not fetch room account data if it will be unused.
- (
- global_account_data,
- account_data_by_room,
- ) = await self.store.get_account_data_for_user(sync_config.user.to_string())
+ all_global_account_data = await self.store.get_global_account_data_for_user(
+ user_id
+ )
+ global_account_data = dict(all_global_account_data)
global_account_data["m.push_rules"] = await self.push_rules_for_user(
sync_config.user
)
- account_data_for_user = await sync_config.filter_collection.filter_account_data(
- [
- {"type": account_data_type, "content": content}
- for account_data_type, content in global_account_data.items()
- ]
+ account_data_for_user = (
+ await sync_config.filter_collection.filter_global_account_data(
+ [
+ {"type": account_data_type, "content": content}
+ for account_data_type, content in global_account_data.items()
+ ]
+ )
)
sync_result_builder.account_data = account_data_for_user
- return account_data_by_room
-
async def _generate_sync_entry_for_presence(
self,
sync_result_builder: "SyncResultBuilder",
@@ -1837,9 +1842,7 @@ class SyncHandler:
sync_result_builder.presence = presence
async def _generate_sync_entry_for_rooms(
- self,
- sync_result_builder: "SyncResultBuilder",
- account_data_by_room: Dict[str, Dict[str, JsonDict]],
+ self, sync_result_builder: "SyncResultBuilder"
) -> Tuple[AbstractSet[str], AbstractSet[str]]:
"""Generates the rooms portion of the sync response. Populates the
`sync_result_builder` with the result.
@@ -1850,7 +1853,6 @@ class SyncHandler:
Args:
sync_result_builder
- account_data_by_room: Dictionary of per room account data
Returns:
Returns a 2-tuple describing rooms the user has joined or left.
@@ -1863,9 +1865,30 @@ class SyncHandler:
since_token = sync_result_builder.since_token
user_id = sync_result_builder.sync_config.user.to_string()
+ blocks_all_rooms = (
+ sync_result_builder.sync_config.filter_collection.blocks_all_rooms()
+ )
+
+ # 0. Start by fetching room account data (if required).
+ if (
+ blocks_all_rooms
+ or sync_result_builder.sync_config.filter_collection.blocks_all_room_account_data()
+ ):
+ account_data_by_room: Mapping[str, Mapping[str, JsonDict]] = {}
+ elif since_token and not sync_result_builder.full_state:
+ account_data_by_room = (
+ await self.store.get_updated_room_account_data_for_user(
+ user_id, since_token.account_data_key
+ )
+ )
+ else:
+ account_data_by_room = await self.store.get_room_account_data_for_user(
+ user_id
+ )
+
# 1. Start by fetching all ephemeral events in rooms we've joined (if required).
block_all_room_ephemeral = (
- sync_result_builder.sync_config.filter_collection.blocks_all_rooms()
+ blocks_all_rooms
or sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral()
)
if block_all_room_ephemeral:
@@ -2291,8 +2314,8 @@ class SyncHandler:
sync_result_builder: "SyncResultBuilder",
room_builder: "RoomSyncResultBuilder",
ephemeral: List[JsonDict],
- tags: Optional[Dict[str, Dict[str, Any]]],
- account_data: Dict[str, JsonDict],
+ tags: Optional[Mapping[str, Mapping[str, Any]]],
+ account_data: Mapping[str, JsonDict],
always_include: bool = False,
) -> None:
"""Populates the `joined` and `archived` section of `sync_result_builder`
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 2563858f3c..9314454af1 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -30,7 +30,6 @@ from typing import (
Iterable,
Iterator,
List,
- NoReturn,
Optional,
Pattern,
Tuple,
@@ -340,7 +339,8 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
return callback_return
- return _unrecognised_request_handler(request)
+ # A request with an unknown method (for a known endpoint) was received.
+ raise UnrecognizedRequestError(code=405)
@abc.abstractmethod
def _send_response(
@@ -396,7 +396,6 @@ class DirectServeJsonResource(_AsyncResource):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _PathEntry:
- pattern: Pattern
callback: ServletCallback
servlet_classname: str
@@ -425,13 +424,14 @@ class JsonResource(DirectServeJsonResource):
):
super().__init__(canonical_json, extract_context)
self.clock = hs.get_clock()
- self.path_regexs: Dict[bytes, List[_PathEntry]] = {}
+ # Map of path regex -> method -> callback.
+ self._routes: Dict[Pattern[str], Dict[bytes, _PathEntry]] = {}
self.hs = hs
def register_paths(
self,
method: str,
- path_patterns: Iterable[Pattern],
+ path_patterns: Iterable[Pattern[str]],
callback: ServletCallback,
servlet_classname: str,
) -> None:
@@ -455,8 +455,8 @@ class JsonResource(DirectServeJsonResource):
for path_pattern in path_patterns:
logger.debug("Registering for %s %s", method, path_pattern.pattern)
- self.path_regexs.setdefault(method_bytes, []).append(
- _PathEntry(path_pattern, callback, servlet_classname)
+ self._routes.setdefault(path_pattern, {})[method_bytes] = _PathEntry(
+ callback, servlet_classname
)
def _get_handler_for_request(
@@ -478,14 +478,17 @@ class JsonResource(DirectServeJsonResource):
# Loop through all the registered callbacks to check if the method
# and path regex match
- for path_entry in self.path_regexs.get(request_method, []):
- m = path_entry.pattern.match(request_path)
+ for path_pattern, methods in self._routes.items():
+ m = path_pattern.match(request_path)
if m:
- # We found a match!
+ # We found a matching path!
+ path_entry = methods.get(request_method)
+ if not path_entry:
+ raise UnrecognizedRequestError(code=405)
return path_entry.callback, path_entry.servlet_classname, m.groupdict()
- # Huh. No one wanted to handle that? Fiiiiiine. Send 400.
- return _unrecognised_request_handler, "unrecognised_request_handler", {}
+ # Huh. No one wanted to handle that? Fiiiiiine.
+ raise UnrecognizedRequestError(code=404)
async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]:
callback, servlet_classname, group_dict = self._get_handler_for_request(request)
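The two-level routing table is what makes the 405/404 split possible: a path pattern can match while its method map does not. A self-contained sketch of the same lookup (illustrative, not the JsonResource code itself):

    import re
    from typing import Callable, Dict, Pattern, Tuple

    class HttpError(Exception):
        def __init__(self, code: int) -> None:
            self.code = code

    # Map of path regex -> HTTP method -> handler.
    routes: Dict[Pattern[str], Dict[bytes, Callable[..., object]]] = {}

    def register(
        method: str, pattern: Pattern[str], handler: Callable[..., object]
    ) -> None:
        routes.setdefault(pattern, {})[method.encode("ascii")] = handler

    def lookup(
        method: bytes, path: str
    ) -> Tuple[Callable[..., object], Dict[str, str]]:
        for pattern, methods in routes.items():
            m = pattern.match(path)
            if m:
                handler = methods.get(method)
                if handler is None:
                    raise HttpError(405)  # known path, unknown method
                return handler, m.groupdict()
        raise HttpError(404)  # no path matched at all

    register("GET", re.compile(r"^/widgets/(?P<widget_id>[^/]+)$"), lambda: None)
    lookup(b"GET", "/widgets/abc")      # matches
    # lookup(b"POST", "/widgets/abc")   would raise HttpError(405)
    # lookup(b"GET", "/gadgets/abc")    would raise HttpError(404)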
@@ -567,19 +570,6 @@ class StaticResource(File):
return super().render_GET(request)
-def _unrecognised_request_handler(request: Request) -> NoReturn:
- """Request handler for unrecognised requests
-
- This is a request handler suitable for return from
- _get_handler_for_request. It actually just raises an
- UnrecognizedRequestError.
-
- Args:
- request: Unused, but passed in to match the signature of ServletCallback.
- """
- raise UnrecognizedRequestError(code=404)
-
-
class UnrecognizedRequestResource(resource.Resource):
"""
Similar to twisted.web.resource.NoResource, but returns a JSON 404 with an
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 8ef9a0dda8..6c7cf1b294 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -466,8 +466,16 @@ def init_tracer(hs: "HomeServer") -> None:
STRIP_INSTANCE_NUMBER_SUFFIX_REGEX, "", hs.get_instance_name()
)
+ jaeger_config = hs.config.tracing.jaeger_config
+ tags = jaeger_config.setdefault("tags", {})
+
+ # tag the Synapse instance name so that it's an easy jumping
+ # off point into the logs. Can also be used to filter for an
+ # instance that is under load.
+ tags[SynapseTags.INSTANCE_NAME] = hs.get_instance_name()
+
config = JaegerConfig(
- config=hs.config.tracing.jaeger_config,
+ config=jaeger_config,
service_name=f"{hs.config.server.server_name} {instance_name_by_type}",
scope_manager=LogContextScopeManager(),
metrics_factory=PrometheusMetricsFactory(),
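Attaching the instance name as a tracer-level tag means every span emitted by the process carries it without per-request tagging. The `setdefault` also preserves any tags an admin has already configured; a tiny illustration (the tag key and values here are placeholders):

    # Stand-in for the jaeger config dict loaded from the homeserver config;
    # "tags" may or may not already be present.
    jaeger_config = {"sampler": {"type": "const", "param": 1}}

    # setdefault returns the existing dict if the admin already configured
    # tags, otherwise inserts (and returns) a fresh one, so admin-supplied
    # tags survive.
    tags = jaeger_config.setdefault("tags", {})
    tags["synapse.instance_name"] = "worker1"  # placeholder key/value

    assert jaeger_config["tags"]["synapse.instance_name"] == "worker1"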
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index d9c0a98f44..f6a5bffb0f 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -22,6 +22,7 @@ from typing import (
List,
Mapping,
Optional,
+ Sequence,
Set,
Tuple,
Union,
@@ -43,6 +44,7 @@ from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY
from synapse.storage.databases.main.roommember import EventIdMembership
from synapse.synapse_rust.push import FilteredPushRules, PushRuleEvaluator
+from synapse.types import SimpleJsonValue
from synapse.types.state import StateFilter
from synapse.util.caches import register_cache
from synapse.util.metrics import measure_func
@@ -148,7 +150,7 @@ class BulkPushRuleEvaluator:
# little, we can skip fetching a huge number of push rules in large rooms.
# This helps make joins and leaves faster.
if event.type == EventTypes.Member:
- local_users = []
+ local_users: Sequence[str] = []
# We never notify a user about their own actions. This is enforced in
# `_action_for_event_by_user` in the loop over `rules_by_user`, but we
# do the same check here to avoid unnecessary DB queries.
@@ -183,7 +185,6 @@ class BulkPushRuleEvaluator:
if event.type == EventTypes.Member and event.membership == Membership.INVITE:
invited = event.state_key
if invited and self.hs.is_mine_id(invited) and invited not in local_users:
- local_users = list(local_users)
local_users.append(invited)
if not local_users:
@@ -256,13 +257,15 @@ class BulkPushRuleEvaluator:
return pl_event.content if pl_event else {}, sender_level
- async def _related_events(self, event: EventBase) -> Dict[str, Dict[str, str]]:
+ async def _related_events(
+ self, event: EventBase
+ ) -> Dict[str, Dict[str, SimpleJsonValue]]:
"""Fetches the related events for 'event'. Sets the im.vector.is_falling_back key if the event is from a fallback relation
Returns:
Mapping of relation type to flattened events.
"""
- related_events: Dict[str, Dict[str, str]] = {}
+ related_events: Dict[str, Dict[str, SimpleJsonValue]] = {}
if self._related_event_match_enabled:
related_event_id = event.content.get("m.relates_to", {}).get("event_id")
relation_type = event.content.get("m.relates_to", {}).get("rel_type")
@@ -271,7 +274,10 @@ class BulkPushRuleEvaluator:
related_event_id, allow_none=True
)
if related_event is not None:
- related_events[relation_type] = _flatten_dict(related_event)
+ related_events[relation_type] = _flatten_dict(
+ related_event,
+ msc3783_escape_event_match_key=self.hs.config.experimental.msc3783_escape_event_match_key,
+ )
reply_event_id = (
event.content.get("m.relates_to", {})
@@ -286,7 +292,10 @@ class BulkPushRuleEvaluator:
)
if related_event is not None:
- related_events["m.in_reply_to"] = _flatten_dict(related_event)
+ related_events["m.in_reply_to"] = _flatten_dict(
+ related_event,
+ msc3783_escape_event_match_key=self.hs.config.experimental.msc3783_escape_event_match_key,
+ )
# indicate that this is from a fallback relation.
if relation_type == "m.thread" and event.content.get(
@@ -405,7 +414,10 @@ class BulkPushRuleEvaluator:
room_mention = mentions.get("room") is True
evaluator = PushRuleEvaluator(
- _flatten_dict(event),
+ _flatten_dict(
+ event,
+ msc3783_escape_event_match_key=self.hs.config.experimental.msc3783_escape_event_match_key,
+ ),
has_mentions,
user_mentions,
room_mention,
@@ -416,6 +428,7 @@ class BulkPushRuleEvaluator:
self._related_event_match_enabled,
event.room_version.msc3931_push_features,
self.hs.config.experimental.msc1767_enabled, # MSC3931 flag
+ self.hs.config.experimental.msc3758_exact_event_match,
)
users = rules_by_user.keys()
@@ -492,13 +505,15 @@ StateGroup = Union[object, int]
def _flatten_dict(
d: Union[EventBase, Mapping[str, Any]],
prefix: Optional[List[str]] = None,
- result: Optional[Dict[str, str]] = None,
-) -> Dict[str, str]:
+ result: Optional[Dict[str, SimpleJsonValue]] = None,
+ *,
+ msc3783_escape_event_match_key: bool = False,
+) -> Dict[str, SimpleJsonValue]:
"""
Given a JSON dictionary (or event) which might contain sub dictionaries,
flatten it into a single layer dictionary by combining the keys & sub-keys.
- Any (non-dictionary), non-string value is dropped.
+ String, integer, boolean, and null values are kept. All others are dropped.
Transforms:
@@ -521,11 +536,22 @@ def _flatten_dict(
if result is None:
result = {}
for key, value in d.items():
- if isinstance(value, str):
- result[".".join(prefix + [key])] = value.lower()
+ if msc3783_escape_event_match_key:
+ # Escape periods in the key with a backslash (and backslashes with an
+            # extra backslash). This is because a period is used as a separator between
+ # nested fields.
+ key = key.replace("\\", "\\\\").replace(".", "\\.")
+
+ if isinstance(value, (bool, str)) or type(value) is int or value is None:
+ result[".".join(prefix + [key])] = value
elif isinstance(value, Mapping):
# do not set `room_version` due to recursion considerations below
- _flatten_dict(value, prefix=(prefix + [key]), result=result)
+ _flatten_dict(
+ value,
+ prefix=(prefix + [key]),
+ result=result,
+ msc3783_escape_event_match_key=msc3783_escape_event_match_key,
+ )
# `room_version` should only ever be set when looking at the top level of an event
if (
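The MSC3783 escaping only affects how keys are joined: since `.` separates nesting levels, literal periods inside a key are backslash-escaped (and backslashes doubled) first. A self-contained sketch of the flattening under the same rules:

    from typing import Any, Dict, List, Mapping, Optional, Union

    SimpleJsonValue = Optional[Union[str, int, bool]]

    def flatten_dict(
        d: Mapping[str, Any],
        prefix: Optional[List[str]] = None,
        result: Optional[Dict[str, SimpleJsonValue]] = None,
        *,
        escape_keys: bool = False,
    ) -> Dict[str, SimpleJsonValue]:
        if prefix is None:
            prefix = []
        if result is None:
            result = {}
        for key, value in d.items():
            if escape_keys:
                # A period separates nesting levels, so escape literal
                # periods (and the escape character itself) within a key.
                key = key.replace("\\", "\\\\").replace(".", "\\.")
            if isinstance(value, (bool, str)) or type(value) is int or value is None:
                result[".".join(prefix + [key])] = value
            elif isinstance(value, Mapping):
                flatten_dict(value, prefix + [key], result, escape_keys=escape_keys)
            # lists, floats, etc. are dropped
        return result

    content = {"m.relates_to": {"rel_type": "m.thread", "count": 2}}
    assert flatten_dict(content) == {
        "m.relates_to.rel_type": "m.thread",
        "m.relates_to.count": 2,
    }
    assert flatten_dict(content, escape_keys=True) == {
        "m\\.relates_to.rel_type": "m.thread",
        "m\\.relates_to.count": 2,
    }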
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index 0d072c42a7..c134ccfb3d 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -15,7 +15,7 @@
import logging
from http import HTTPStatus
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple
from synapse.api.constants import Direction
from synapse.api.errors import Codes, NotFoundError, SynapseError
@@ -285,7 +285,12 @@ class DeleteMediaByDateSize(RestServlet):
timestamp and size.
"""
- PATTERNS = admin_patterns("/media/(?P<server_name>[^/]*)/delete$")
+ PATTERNS = [
+ *admin_patterns("/media/delete$"),
+        # This URL is kept around for legacy reasons; it is undesirable since it
+        # overlaps with the DeleteMediaByID servlet.
+ *admin_patterns("/media/(?P<server_name>[^/]*)/delete$"),
+ ]
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
@@ -294,7 +299,7 @@ class DeleteMediaByDateSize(RestServlet):
self.media_repository = hs.get_media_repository()
async def on_POST(
- self, request: SynapseRequest, server_name: str
+ self, request: SynapseRequest, server_name: Optional[str] = None
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
@@ -322,7 +327,8 @@ class DeleteMediaByDateSize(RestServlet):
errcode=Codes.INVALID_PARAM,
)
- if self.server_name != server_name:
+        # This check is superfluous; we keep it for the legacy endpoint only.
+ if server_name is not None and self.server_name != server_name:
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local media")
logging.info(
@@ -489,6 +495,8 @@ def register_servlets_for_media_repo(hs: "HomeServer", http_server: HttpServer)
ProtectMediaByID(hs).register(http_server)
UnprotectMediaByID(hs).register(http_server)
ListMediaInRoom(hs).register(http_server)
- DeleteMediaByID(hs).register(http_server)
+ # XXX DeleteMediaByDateSize must be registered before DeleteMediaByID as
+ # their URL routes overlap.
DeleteMediaByDateSize(hs).register(http_server)
+ DeleteMediaByID(hs).register(http_server)
UserMediaRestServlet(hs).register(http_server)
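Registration order matters because the legacy `/media/<server_name>/delete` URL also matches an id-style route. A quick check of the overlap, assuming DeleteMediaByID uses a `/media/<server_name>/<media_id>` shaped pattern (illustrative, not copied from that servlet):

    import re

    delete_by_date_size = re.compile(r"^/media/(?P<server_name>[^/]*)/delete$")
    delete_by_id = re.compile(r"^/media/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$")

    url = "/media/example.com/delete"

    # Both patterns match: if DeleteMediaByID were registered first, it would
    # swallow the request with media_id="delete".
    assert delete_by_date_size.match(url)
    assert delete_by_id.match(url).group("media_id") == "delete"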
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index b9dca8ef3a..0c0bf540b9 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -1192,7 +1192,8 @@ class AccountDataRestServlet(RestServlet):
if not await self._store.get_user_by_id(user_id):
raise NotFoundError("User not found")
- global_data, by_room_data = await self._store.get_account_data_for_user(user_id)
+ global_data = await self._store.get_global_account_data_for_user(user_id)
+ by_room_data = await self._store.get_room_account_data_for_user(user_id)
return HTTPStatus.OK, {
"account_data": {
"global": global_data,
diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index 4373c73662..662f5bf762 100644
--- a/synapse/rest/client/account.py
+++ b/synapse/rest/client/account.py
@@ -415,6 +415,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
request, MsisdnRequestTokenBody
)
msisdn = phone_number_to_msisdn(body.country, body.phone_number)
+ logger.info("Request #%s to verify ownership of %s", body.send_attempt, msisdn)
if not await check_3pid_allowed(self.hs, "msisdn", msisdn):
raise SynapseError(
@@ -444,6 +445,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
+ logger.info("MSISDN %s is already in use by %s", msisdn, existing_user_id)
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
if not self.hs.config.registration.account_threepid_delegate_msisdn:
@@ -468,6 +470,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
threepid_send_requests.labels(type="msisdn", reason="add_threepid").observe(
body.send_attempt
)
+ logger.info("MSISDN %s: got response from identity server: %s", msisdn, ret)
return 200, ret
@@ -734,12 +737,7 @@ class ThreepidUnbindRestServlet(RestServlet):
# Attempt to unbind the threepid from an identity server. If id_server is None, try to
# unbind from all identity servers this threepid has been added to in the past
result = await self.identity_handler.try_unbind_threepid(
- requester.user.to_string(),
- {
- "address": body.address,
- "medium": body.medium,
- "id_server": body.id_server,
- },
+ requester.user.to_string(), body.medium, body.address, body.id_server
)
return 200, {"id_server_unbind_result": "success" if result else "no-support"}
diff --git a/synapse/rest/client/room_keys.py b/synapse/rest/client/room_keys.py
index f7081f638e..4e7ffdb555 100644
--- a/synapse/rest/client/room_keys.py
+++ b/synapse/rest/client/room_keys.py
@@ -259,6 +259,32 @@ class RoomKeysNewVersionServlet(RestServlet):
self.auth = hs.get_auth()
self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ """
+ Retrieve the version information about the most current backup version (if any)
+
+ It takes out an exclusive lock on this user's room_key backups, to ensure
+ clients only upload to the current backup.
+
+        Returns 404 if no backup exists.
+
+        GET /room_keys/version HTTP/1.1
+
+        HTTP/1.1 200 OK
+        {
+ "version": "12345",
+ "algorithm": "m.megolm_backup.v1",
+ "auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K"
+ }
+ """
+ requester = await self.auth.get_user_by_req(request, allow_guest=False)
+ user_id = requester.user.to_string()
+
+ try:
+ info = await self.e2e_room_keys_handler.get_version_info(user_id)
+ except SynapseError as e:
+            if e.code == 404:
+                raise SynapseError(404, "No backup found", Codes.NOT_FOUND)
+            # Re-raise anything else; otherwise the error would be swallowed
+            # and `info` left unbound below.
+            raise
+        return 200, info
+
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"""
Create a new backup version for this user's room_keys with the given
@@ -301,7 +327,7 @@ class RoomKeysNewVersionServlet(RestServlet):
class RoomKeysVersionServlet(RestServlet):
- PATTERNS = client_patterns("/room_keys/version(/(?P<version>[^/]+))?$")
+ PATTERNS = client_patterns("/room_keys/version/(?P<version>[^/]+)$")
def __init__(self, hs: "HomeServer"):
super().__init__()
@@ -309,12 +335,11 @@ class RoomKeysVersionServlet(RestServlet):
self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
async def on_GET(
- self, request: SynapseRequest, version: Optional[str]
+ self, request: SynapseRequest, version: str
) -> Tuple[int, JsonDict]:
"""
Retrieve the version information about a given version of the user's
- room_keys backup. If the version part is missing, returns info about the
- most current backup version (if any)
+ room_keys backup.
It takes out an exclusive lock on this user's room_key backups, to ensure
clients only upload to the current backup.
@@ -339,20 +364,16 @@ class RoomKeysVersionServlet(RestServlet):
return 200, info
async def on_DELETE(
- self, request: SynapseRequest, version: Optional[str]
+ self, request: SynapseRequest, version: str
) -> Tuple[int, JsonDict]:
"""
Delete the information about a given version of the user's
- room_keys backup. If the version part is missing, deletes the most
- current backup version (if any). Doesn't delete the actual room data.
+ room_keys backup. Doesn't delete the actual room data.
DELETE /room_keys/version/12345 HTTP/1.1
HTTP/1.1 200 OK
{}
"""
- if version is None:
- raise SynapseError(400, "No version specified to delete", Codes.NOT_FOUND)
-
requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
@@ -360,7 +381,7 @@ class RoomKeysVersionServlet(RestServlet):
return 200, {}
async def on_PUT(
- self, request: SynapseRequest, version: Optional[str]
+ self, request: SynapseRequest, version: str
) -> Tuple[int, JsonDict]:
"""
Update the information about a given version of the user's room_keys backup.
@@ -386,11 +407,6 @@ class RoomKeysVersionServlet(RestServlet):
user_id = requester.user.to_string()
info = parse_json_object_from_request(request)
- if version is None:
- raise SynapseError(
- 400, "No version specified to update", Codes.MISSING_PARAM
- )
-
await self.e2e_room_keys_handler.update_version(user_id, version, info)
return 200, {}
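With the optional-version pattern split into two servlets, the `version is None` guards become unreachable and are deleted; clients address the two cases with distinct URLs. A hypothetical helper showing the distinction (URLs per the Matrix client-server API):

    from typing import Optional

    BASE = "https://hs.example/_matrix/client/v3"

    def version_info_url(version: Optional[str] = None) -> str:
        # No version: the most recent backup (the new on_GET above);
        # otherwise a specific backup (RoomKeysVersionServlet).
        if version is None:
            return f"{BASE}/room_keys/version"
        return f"{BASE}/room_keys/version/{version}"

    assert version_info_url() == f"{BASE}/room_keys/version"
    assert version_info_url("12345") == f"{BASE}/room_keys/version/12345"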
diff --git a/synapse/rest/client/tags.py b/synapse/rest/client/tags.py
index ca638755c7..dde08417a4 100644
--- a/synapse/rest/client/tags.py
+++ b/synapse/rest/client/tags.py
@@ -34,7 +34,9 @@ class TagListServlet(RestServlet):
GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1
"""
- PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags")
+ PATTERNS = client_patterns(
+ "/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags$"
+ )
def __init__(self, hs: "HomeServer"):
super().__init__()
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index a5c3de192f..db25848744 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -46,10 +46,9 @@ from ._base import FileInfo, Responder
from .filepath import MediaFilePaths
if TYPE_CHECKING:
+ from synapse.rest.media.v1.storage_provider import StorageProvider
from synapse.server import HomeServer
- from .storage_provider import StorageProviderWrapper
-
logger = logging.getLogger(__name__)
@@ -68,7 +67,7 @@ class MediaStorage:
hs: "HomeServer",
local_media_directory: str,
filepaths: MediaFilePaths,
- storage_providers: Sequence["StorageProviderWrapper"],
+ storage_providers: Sequence["StorageProvider"],
):
self.hs = hs
self.reactor = hs.get_reactor()
@@ -360,7 +359,7 @@ class ReadableFileWrapper:
clock: Clock
path: str
- async def write_chunks_to(self, callback: Callable[[bytes], None]) -> None:
+ async def write_chunks_to(self, callback: Callable[[bytes], object]) -> None:
"""Reads the file in chunks and calls the callback with each chunk."""
with open(self.path, "rb") as file:
diff --git a/synapse/server.py b/synapse/server.py
index 9d6d268f49..efc6b5f895 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -21,7 +21,7 @@
import abc
import functools
import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, cast
from twisted.internet.interfaces import IOpenSSLContextFactory
from twisted.internet.tcp import Port
@@ -144,10 +144,10 @@ if TYPE_CHECKING:
from synapse.handlers.saml import SamlHandler
-T = TypeVar("T", bound=Callable[..., Any])
+T = TypeVar("T")
-def cache_in_self(builder: T) -> T:
+def cache_in_self(builder: Callable[["HomeServer"], T]) -> Callable[["HomeServer"], T]:
"""Wraps a function called e.g. `get_foo`, checking if `self.foo` exists and
returning if so. If not, calls the given function and sets `self.foo` to it.
@@ -166,7 +166,7 @@ def cache_in_self(builder: T) -> T:
building = [False]
@functools.wraps(builder)
- def _get(self):
+ def _get(self: "HomeServer") -> T:
try:
return getattr(self, depname)
except AttributeError:
@@ -185,9 +185,7 @@ def cache_in_self(builder: T) -> T:
return dep
- # We cast here as we need to tell mypy that `_get` has the same signature as
- # `builder`.
- return cast(T, _get)
+ return _get
class HomeServer(metaclass=abc.ABCMeta):
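Typing `cache_in_self` over a plain `T` lets mypy infer each getter's return type directly, which is what removes the cast. A stripped-down sketch of the same memoise-on-self pattern (minus the re-entrancy guard; `Server` is a stand-in class):

    import functools
    import time
    from typing import Callable, TypeVar

    T = TypeVar("T")

    def cache_in_self(builder: Callable[["Server"], T]) -> Callable[["Server"], T]:
        # `get_foo` caches its result under `self.foo` after the first call.
        depname = builder.__name__.replace("get_", "")

        @functools.wraps(builder)
        def _get(self: "Server") -> T:
            try:
                return getattr(self, depname)
            except AttributeError:
                dep = builder(self)
                setattr(self, depname, dep)
                return dep

        return _get

    class Server:
        @cache_in_self
        def get_clock(self) -> float:
            return time.time()

    s = Server()
    assert s.get_clock() == s.get_clock()  # built once, then cached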
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index fdfb46ab82..4dc25df67e 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -39,7 +39,11 @@ from prometheus_client import Counter, Histogram
from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import (
+ EventContext,
+ UnpersistedEventContext,
+ UnpersistedEventContextBase,
+)
from synapse.logging.context import ContextResourceUsage
from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet
from synapse.state import v1, v2
@@ -222,7 +226,7 @@ class StateHandler:
return await ret.get_state(self._state_storage_controller, state_filter)
async def get_current_user_ids_in_room(
- self, room_id: str, latest_event_ids: List[str]
+ self, room_id: str, latest_event_ids: Collection[str]
) -> Set[str]:
"""
Get the users IDs who are currently in a room.
@@ -262,31 +266,31 @@ class StateHandler:
state = await entry.get_state(self._state_storage_controller, StateFilter.all())
return await self.store.get_joined_hosts(room_id, state, entry)
- async def compute_event_context(
+ async def calculate_context_info(
self,
event: EventBase,
state_ids_before_event: Optional[StateMap[str]] = None,
partial_state: Optional[bool] = None,
- ) -> EventContext:
- """Build an EventContext structure for a non-outlier event.
-
- (for an outlier, call EventContext.for_outlier directly)
-
- This works out what the current state should be for the event, and
- generates a new state group if necessary.
-
- Args:
- event:
- state_ids_before_event: The event ids of the state before the event if
- it can't be calculated from existing events. This is normally
- only specified when receiving an event from federation where we
- don't have the prev events, e.g. when backfilling.
- partial_state:
- `True` if `state_ids_before_event` is partial and omits non-critical
- membership events.
- `False` if `state_ids_before_event` is the full state.
- `None` when `state_ids_before_event` is not provided. In this case, the
- flag will be calculated based on `event`'s prev events.
+ state_group_before_event: Optional[int] = None,
+ ) -> UnpersistedEventContextBase:
+ """
+        Calculates the contents of an unpersisted event context, other than the
+        current state group (which is either provided or calculated when the event
+        context is persisted).
+
+ state_ids_before_event:
+ The event ids of the full state before the event if
+ it can't be calculated from existing events. This is normally
+ only specified when receiving an event from federation where we
+ don't have the prev events, e.g. when backfilling or when the event
+ is being created for batch persisting.
+ partial_state:
+ `True` if `state_ids_before_event` is partial and omits non-critical
+ membership events.
+ `False` if `state_ids_before_event` is the full state.
+ `None` when `state_ids_before_event` is not provided. In this case, the
+ flag will be calculated based on `event`'s prev events.
+ state_group_before_event:
+            The current state group at the time of the event, if known.
Returns:
The event context.
@@ -294,7 +298,6 @@ class StateHandler:
RuntimeError if `state_ids_before_event` is not provided and one or more
prev events are missing or outliers.
"""
-
assert not event.internal_metadata.is_outlier()
#
@@ -306,17 +309,6 @@ class StateHandler:
state_group_before_event_prev_group = None
deltas_to_state_group_before_event = None
- # .. though we need to get a state group for it.
- state_group_before_event = (
- await self._state_storage_controller.store_state_group(
- event.event_id,
- event.room_id,
- prev_group=None,
- delta_ids=None,
- current_state_ids=state_ids_before_event,
- )
- )
-
# the partial_state flag must be provided
assert partial_state is not None
else:
@@ -345,6 +337,7 @@ class StateHandler:
logger.debug("calling resolve_state_groups from compute_event_context")
# we've already taken into account partial state, so no need to wait for
# complete state here.
+
entry = await self.resolve_state_groups_for_events(
event.room_id,
event.prev_event_ids(),
@@ -383,18 +376,19 @@ class StateHandler:
#
if not event.is_state():
- return EventContext.with_state(
+ return UnpersistedEventContext(
storage=self._storage_controllers,
state_group_before_event=state_group_before_event,
- state_group=state_group_before_event,
+ state_group_after_event=state_group_before_event,
state_delta_due_to_event={},
- prev_group=state_group_before_event_prev_group,
- delta_ids=deltas_to_state_group_before_event,
+ prev_group_for_state_group_before_event=state_group_before_event_prev_group,
+ delta_ids_to_state_group_before_event=deltas_to_state_group_before_event,
partial_state=partial_state,
+ state_map_before_event=state_ids_before_event,
)
#
- # otherwise, we'll need to create a new state group for after the event
+ # otherwise, we'll need to set up creating a new state group for after the event
#
key = (event.type, event.state_key)
@@ -412,88 +406,60 @@ class StateHandler:
delta_ids = {key: event.event_id}
- state_group_after_event = (
- await self._state_storage_controller.store_state_group(
- event.event_id,
- event.room_id,
- prev_group=state_group_before_event,
- delta_ids=delta_ids,
- current_state_ids=None,
- )
- )
-
- return EventContext.with_state(
+ return UnpersistedEventContext(
storage=self._storage_controllers,
- state_group=state_group_after_event,
state_group_before_event=state_group_before_event,
+ state_group_after_event=None,
state_delta_due_to_event=delta_ids,
- prev_group=state_group_before_event,
- delta_ids=delta_ids,
+ prev_group_for_state_group_before_event=state_group_before_event_prev_group,
+ delta_ids_to_state_group_before_event=deltas_to_state_group_before_event,
partial_state=partial_state,
+ state_map_before_event=state_ids_before_event,
)
- async def compute_event_context_for_batched(
+ async def compute_event_context(
self,
event: EventBase,
- state_ids_before_event: StateMap[str],
- current_state_group: int,
+ state_ids_before_event: Optional[StateMap[str]] = None,
+ partial_state: Optional[bool] = None,
) -> EventContext:
- """
- Generate an event context for an event that has not yet been persisted to the
- database. Intended for use with events that are created to be persisted in a batch.
- Args:
- event: the event the context is being computed for
- state_ids_before_event: a state map consisting of the state ids of the events
- created prior to this event.
- current_state_group: the current state group before the event.
- """
- state_group_before_event_prev_group = None
- deltas_to_state_group_before_event = None
-
- state_group_before_event = current_state_group
-
- # if the event is not state, we are set
- if not event.is_state():
- return EventContext.with_state(
- storage=self._storage_controllers,
- state_group_before_event=state_group_before_event,
- state_group=state_group_before_event,
- state_delta_due_to_event={},
- prev_group=state_group_before_event_prev_group,
- delta_ids=deltas_to_state_group_before_event,
- partial_state=False,
- )
+ """Build an EventContext structure for a non-outlier event.
- # otherwise, we'll need to create a new state group for after the event
- key = (event.type, event.state_key)
+ (for an outlier, call EventContext.for_outlier directly)
- if state_ids_before_event is not None:
- replaces = state_ids_before_event.get(key)
+ This works out what the current state should be for the event, and
+ generates a new state group if necessary.
- if replaces and replaces != event.event_id:
- event.unsigned["replaces_state"] = replaces
+ Args:
+ event:
+ state_ids_before_event: The event ids of the state before the event if
+ it can't be calculated from existing events. This is normally
+ only specified when receiving an event from federation where we
+ don't have the prev events, e.g. when backfilling.
+ partial_state:
+ `True` if `state_ids_before_event` is partial and omits non-critical
+ membership events.
+ `False` if `state_ids_before_event` is the full state.
+ `None` when `state_ids_before_event` is not provided. In this case, the
+ flag will be calculated based on `event`'s prev events.
+ Returns:
+ The event context.
- delta_ids = {key: event.event_id}
+ Raises:
+ RuntimeError if `state_ids_before_event` is not provided and one or more
+ prev events are missing or outliers.
+ """
- state_group_after_event = (
- await self._state_storage_controller.store_state_group(
- event.event_id,
- event.room_id,
- prev_group=state_group_before_event,
- delta_ids=delta_ids,
- current_state_ids=None,
- )
+ unpersisted_context = await self.calculate_context_info(
+ event=event,
+ state_ids_before_event=state_ids_before_event,
+ partial_state=partial_state,
)
- return EventContext.with_state(
- storage=self._storage_controllers,
- state_group=state_group_after_event,
- state_group_before_event=state_group_before_event,
- state_delta_due_to_event=delta_ids,
- prev_group=state_group_before_event,
- delta_ids=delta_ids,
- partial_state=False,
- )
+ return await unpersisted_context.persist(event)
@measure_func()
async def resolve_state_groups_for_events(
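The shape of the refactor: `calculate_context_info` works everything out except the state-group write, and persisting the unpersisted context performs that write. A much-simplified sketch of the two-phase split (the store and field names here are stand-ins, not the real snapshot API):

    import asyncio
    from dataclasses import dataclass
    from typing import Dict, Optional, Tuple

    StateMap = Dict[Tuple[str, str], str]

    class FakeStateStore:
        def __init__(self) -> None:
            self._next_group = 0

        async def store_state_group(
            self, event_id: str, room_id: str,
            prev_group: Optional[int], delta_ids: StateMap,
        ) -> int:
            self._next_group += 1
            return self._next_group

    @dataclass
    class UnpersistedContext:
        state_group_before_event: Optional[int]
        state_delta_due_to_event: StateMap
        partial_state: bool

        async def persist(
            self, store: FakeStateStore, event_id: str, room_id: str
        ) -> int:
            # Only now is the new state group written out; everything else
            # was computed up front, so batched events can defer this write.
            return await store.store_state_group(
                event_id, room_id,
                prev_group=self.state_group_before_event,
                delta_ids=self.state_delta_due_to_event,
            )

    ctx = UnpersistedContext(None, {("m.room.member", "@a:hs"): "$ev"}, False)
    group = asyncio.run(ctx.persist(FakeStateStore(), "$ev", "!room:hs"))
    assert group == 1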
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 41d9111019..481fec72fe 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -37,6 +37,8 @@ class SQLBaseStore(metaclass=ABCMeta):
per data store (and not one per physical database).
"""
+ db_pool: DatabasePool
+
def __init__(
self,
database: DatabasePool,
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 52efd4a171..9d7a8a792f 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -14,6 +14,7 @@
import logging
from typing import (
TYPE_CHECKING,
+ AbstractSet,
Any,
Awaitable,
Callable,
@@ -23,7 +24,6 @@ from typing import (
List,
Mapping,
Optional,
- Set,
Tuple,
)
@@ -527,7 +527,7 @@ class StateStorageController:
)
return state_map.get(key)
- async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
+ async def get_current_hosts_in_room(self, room_id: str) -> AbstractSet[str]:
"""Get current hosts in room based on current state.
Blocks until we have full state for the given room. This only happens for rooms
@@ -584,7 +584,7 @@ class StateStorageController:
async def get_users_in_room_with_profiles(
self, room_id: str
- ) -> Dict[str, ProfileInfo]:
+ ) -> Mapping[str, ProfileInfo]:
"""
Get the current users in the room with their profiles.
If the room is currently partial-stated, this will block until the room has
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index e20c5c5302..feaa6cdd07 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -499,6 +499,7 @@ class DatabasePool:
"""
_TXN_ID = 0
+ engine: BaseDatabaseEngine
def __init__(
self,
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 8a359d7eb8..95567826f2 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -21,6 +21,7 @@ from typing import (
FrozenSet,
Iterable,
List,
+ Mapping,
Optional,
Tuple,
cast,
@@ -122,25 +123,25 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
return self._account_data_id_gen.get_current_token()
@cached()
- async def get_account_data_for_user(
+ async def get_global_account_data_for_user(
self, user_id: str
- ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
+ ) -> Mapping[str, JsonDict]:
"""
- Get all the client account_data for a user.
+ Get all the global client account_data for a user.
If experimental MSC3391 support is enabled, any entries with an empty
content body are excluded; as this means they have been deleted.
Args:
user_id: The user to get the account_data for.
+
Returns:
- A 2-tuple of a dict of global account_data and a dict mapping from
- room_id string to per room account_data dicts.
+ The global account_data.
"""
- def get_account_data_for_user_txn(
+ def get_global_account_data_for_user(
txn: LoggingTransaction,
- ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
+ ) -> Dict[str, JsonDict]:
# The 'content != '{}' condition below prevents us from using
# `simple_select_list_txn` here, as it doesn't support conditions
# other than 'equals'.
@@ -158,10 +159,34 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
txn.execute(sql, (user_id,))
rows = self.db_pool.cursor_to_dict(txn)
- global_account_data = {
+ return {
row["account_data_type"]: db_to_json(row["content"]) for row in rows
}
+ return await self.db_pool.runInteraction(
+ "get_global_account_data_for_user", get_global_account_data_for_user
+ )
+
+ @cached()
+ async def get_room_account_data_for_user(
+ self, user_id: str
+ ) -> Mapping[str, Mapping[str, JsonDict]]:
+ """
+ Get all of the per-room client account_data for a user.
+
+ If experimental MSC3391 support is enabled, any entries with an empty
+    content body are excluded, as this means they have been deleted.
+
+ Args:
+ user_id: The user to get the account_data for.
+
+ Returns:
+ A dict mapping from room_id string to per-room account_data dicts.
+ """
+
+ def get_room_account_data_for_user_txn(
+ txn: LoggingTransaction,
+ ) -> Dict[str, Dict[str, JsonDict]]:
# The 'content != '{}' condition below prevents us from using
# `simple_select_list_txn` here, as it doesn't support conditions
# other than 'equals'.
@@ -185,10 +210,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
room_data[row["account_data_type"]] = db_to_json(row["content"])
- return global_account_data, by_room
+ return by_room
return await self.db_pool.runInteraction(
- "get_account_data_for_user", get_account_data_for_user_txn
+ "get_room_account_data_for_user_txn", get_room_account_data_for_user_txn
)
@cached(num_args=2, max_entries=5000, tree=True)
@@ -215,7 +240,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
@cached(num_args=2, tree=True)
async def get_account_data_for_room(
self, user_id: str, room_id: str
- ) -> Dict[str, JsonDict]:
+ ) -> Mapping[str, JsonDict]:
"""Get all the client account_data for a user for a room.
Args:
@@ -342,36 +367,61 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
"get_updated_room_account_data", get_updated_room_account_data_txn
)
- async def get_updated_account_data_for_user(
+ async def get_updated_global_account_data_for_user(
self, user_id: str, stream_id: int
- ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
- """Get all the client account_data for a that's changed for a user
+ ) -> Dict[str, JsonDict]:
+ """Get all the global account_data that's changed for a user.
Args:
user_id: The user to get the account_data for.
stream_id: The point in the stream since which to get updates
+
Returns:
- A deferred pair of a dict of global account_data and a dict
- mapping from room_id string to per room account_data dicts.
+ A dict of global account_data.
"""
- def get_updated_account_data_for_user_txn(
+ def get_updated_global_account_data_for_user(
txn: LoggingTransaction,
- ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
- sql = (
- "SELECT account_data_type, content FROM account_data"
- " WHERE user_id = ? AND stream_id > ?"
- )
-
+ ) -> Dict[str, JsonDict]:
+ sql = """
+ SELECT account_data_type, content FROM account_data
+ WHERE user_id = ? AND stream_id > ?
+ """
txn.execute(sql, (user_id, stream_id))
- global_account_data = {row[0]: db_to_json(row[1]) for row in txn}
+ return {row[0]: db_to_json(row[1]) for row in txn}
- sql = (
- "SELECT room_id, account_data_type, content FROM room_account_data"
- " WHERE user_id = ? AND stream_id > ?"
- )
+ changed = self._account_data_stream_cache.has_entity_changed(
+ user_id, int(stream_id)
+ )
+ if not changed:
+ return {}
+
+ return await self.db_pool.runInteraction(
+ "get_updated_global_account_data_for_user",
+ get_updated_global_account_data_for_user,
+ )
+
+ async def get_updated_room_account_data_for_user(
+ self, user_id: str, stream_id: int
+ ) -> Dict[str, Dict[str, JsonDict]]:
+ """Get all the room account_data that's changed for a user.
+ Args:
+ user_id: The user to get the account_data for.
+ stream_id: The point in the stream since which to get updates
+
+ Returns:
+ A dict mapping from room_id string to per room account_data dicts.
+ """
+
+ def get_updated_room_account_data_for_user_txn(
+ txn: LoggingTransaction,
+ ) -> Dict[str, Dict[str, JsonDict]]:
+ sql = """
+ SELECT room_id, account_data_type, content FROM room_account_data
+ WHERE user_id = ? AND stream_id > ?
+ """
txn.execute(sql, (user_id, stream_id))
account_data_by_room: Dict[str, Dict[str, JsonDict]] = {}
@@ -379,16 +429,17 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
room_account_data = account_data_by_room.setdefault(row[0], {})
room_account_data[row[1]] = db_to_json(row[2])
- return global_account_data, account_data_by_room
+ return account_data_by_room
changed = self._account_data_stream_cache.has_entity_changed(
user_id, int(stream_id)
)
if not changed:
- return {}, {}
+ return {}
return await self.db_pool.runInteraction(
- "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
+ "get_updated_room_account_data_for_user",
+ get_updated_room_account_data_for_user_txn,
)
@cached(max_entries=5000, iterable=True)
@@ -444,7 +495,8 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
self.get_global_account_data_by_type_for_user.invalidate(
(row.user_id, row.data_type)
)
- self.get_account_data_for_user.invalidate((row.user_id,))
+ self.get_global_account_data_for_user.invalidate((row.user_id,))
+ self.get_room_account_data_for_user.invalidate((row.user_id,))
self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
self.get_account_data_for_room_and_type.invalidate(
(row.user_id, row.room_id, row.data_type)
@@ -492,7 +544,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
)
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
- self.get_account_data_for_user.invalidate((user_id,))
+ self.get_room_account_data_for_user.invalidate((user_id,))
self.get_account_data_for_room.invalidate((user_id, room_id))
self.get_account_data_for_room_and_type.prefill(
(user_id, room_id, account_data_type), content
@@ -558,7 +610,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
return None
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
- self.get_account_data_for_user.invalidate((user_id,))
+ self.get_room_account_data_for_user.invalidate((user_id,))
self.get_account_data_for_room.invalidate((user_id, room_id))
self.get_account_data_for_room_and_type.prefill(
(user_id, room_id, account_data_type), {}
@@ -593,7 +645,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
)
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
- self.get_account_data_for_user.invalidate((user_id,))
+ self.get_global_account_data_for_user.invalidate((user_id,))
self.get_global_account_data_by_type_for_user.invalidate(
(user_id, account_data_type)
)
@@ -761,7 +813,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
return None
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
- self.get_account_data_for_user.invalidate((user_id,))
+ self.get_global_account_data_for_user.invalidate((user_id,))
self.get_global_account_data_by_type_for_user.prefill(
(user_id, account_data_type), {}
)
@@ -822,7 +874,10 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
txn, self.get_account_data_for_room_and_type, (user_id,)
)
self._invalidate_cache_and_stream(
- txn, self.get_account_data_for_user, (user_id,)
+ txn, self.get_global_account_data_for_user, (user_id,)
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_room_account_data_for_user, (user_id,)
)
self._invalidate_cache_and_stream(
txn, self.get_global_account_data_by_type_for_user, (user_id,)
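With the combined getter split, each write path only needs to invalidate the cache its data affects: room-level writes invalidate `get_room_account_data_for_user`, global writes invalidate `get_global_account_data_for_user`. A toy illustration of why the split makes invalidation cheaper (a hand-rolled memo cache standing in for `@cached`):

    from typing import Callable, Dict, TypeVar

    T = TypeVar("T")

    class Cached:
        """A tiny per-argument memo cache, standing in for @cached."""

        def __init__(self, fn: Callable[[str], T]) -> None:
            self.fn = fn
            self.cache: Dict[str, T] = {}
            self.misses = 0

        def __call__(self, user_id: str) -> T:
            if user_id not in self.cache:
                self.misses += 1
                self.cache[user_id] = self.fn(user_id)
            return self.cache[user_id]

        def invalidate(self, user_id: str) -> None:
            self.cache.pop(user_id, None)

    global_data = Cached(lambda user_id: {"m.push_rules": {}})
    room_data = Cached(lambda user_id: {"!room:example.com": {}})

    global_data("@alice:example.com")
    room_data("@alice:example.com")

    # A per-room write only touches the room-level cache; the (potentially
    # large) global entry stays warm instead of being rebuilt.
    room_data.invalidate("@alice:example.com")
    room_data("@alice:example.com")
    assert global_data.misses == 1 and room_data.misses == 2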
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 5fb152c4ff..484db175d0 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -166,7 +166,7 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore):
room_id: str,
app_service: "ApplicationService",
cache_context: _CacheContext,
- ) -> List[str]:
+ ) -> Sequence[str]:
"""
Get all users in a room that the appservice controls.
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index e8b6cc6b80..1ca66d57d4 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -21,6 +21,7 @@ from typing import (
Dict,
Iterable,
List,
+ Mapping,
Optional,
Set,
Tuple,
@@ -100,6 +101,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
("device_lists_outbound_pokes", "stream_id"),
("device_lists_changes_in_room", "stream_id"),
("device_lists_remote_pending", "stream_id"),
+ ("device_lists_changes_converted_stream_position", "stream_id"),
],
is_writer=hs.config.worker.worker_app is None,
)
@@ -201,7 +203,9 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token()
- async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> int:
+ async def count_devices_by_users(
+ self, user_ids: Optional[Collection[str]] = None
+ ) -> int:
"""Retrieve number of all devices of given users.
Only returns number of devices that are not marked as hidden.
@@ -212,7 +216,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
"""
def count_devices_by_users_txn(
- txn: LoggingTransaction, user_ids: List[str]
+ txn: LoggingTransaction, user_ids: Collection[str]
) -> int:
sql = """
SELECT count(*)
@@ -745,42 +749,47 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
@trace
@cancellable
async def get_user_devices_from_cache(
- self, query_list: List[Tuple[str, Optional[str]]]
- ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
+ self, user_ids: Set[str], user_and_device_ids: List[Tuple[str, str]]
+ ) -> Tuple[Set[str], Dict[str, Mapping[str, JsonDict]]]:
"""Get the devices (and keys if any) for remote users from the cache.
Args:
- query_list: List of (user_id, device_ids), if device_ids is
- falsey then return all device ids for that user.
+            user_ids: users for which all device IDs should be returned
+            user_and_device_ids: list of (user_id, device_id) pairs
Returns:
A tuple of (user_ids_not_in_cache, results_map), where
user_ids_not_in_cache is a set of user_ids and results_map is a
mapping of user_id -> device_id -> device_info.
"""
- user_ids = {user_id for user_id, _ in query_list}
- user_map = await self.get_device_list_last_stream_id_for_remotes(list(user_ids))
+ unique_user_ids = user_ids | {user_id for user_id, _ in user_and_device_ids}
+ user_map = await self.get_device_list_last_stream_id_for_remotes(
+ list(unique_user_ids)
+ )
# We go and check if any of the users need to have their device lists
# resynced. If they do then we remove them from the cached list.
users_needing_resync = await self.get_user_ids_requiring_device_list_resync(
- user_ids
+ unique_user_ids
)
user_ids_in_cache = {
user_id for user_id, stream_id in user_map.items() if stream_id
} - users_needing_resync
- user_ids_not_in_cache = user_ids - user_ids_in_cache
-
- results: Dict[str, Dict[str, JsonDict]] = {}
- for user_id, device_id in query_list:
- if user_id not in user_ids_in_cache:
- continue
+ user_ids_not_in_cache = unique_user_ids - user_ids_in_cache
- if device_id:
- device = await self._get_cached_user_device(user_id, device_id)
- results.setdefault(user_id, {})[device_id] = device
- else:
+        # First fetch the users for which all devices are to be returned.
+ results: Dict[str, Mapping[str, JsonDict]] = {}
+ for user_id in user_ids:
+ if user_id in user_ids_in_cache:
results[user_id] = await self.get_cached_devices_for_user(user_id)
+ # Then fetch all device-specific requests, but skip users we've already
+ # fetched all devices for.
+ device_specific_results: Dict[str, Dict[str, JsonDict]] = {}
+ for user_id, device_id in user_and_device_ids:
+ if user_id in user_ids_in_cache and user_id not in user_ids:
+ device = await self._get_cached_user_device(user_id, device_id)
+ device_specific_results.setdefault(user_id, {})[device_id] = device
+ results.update(device_specific_results)
set_tag("in_cache", str(results))
set_tag("not_in_cache", str(user_ids_not_in_cache))
@@ -798,7 +807,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
return db_to_json(content)
@cached()
- async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]:
+ async def get_cached_devices_for_user(self, user_id: str) -> Mapping[str, JsonDict]:
devices = await self.db_pool.simple_select_list(
table="device_lists_remote_cache",
keyvalues={"user_id": user_id},
diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py
index 5903fdaf00..44aa181174 100644
--- a/synapse/storage/databases/main/directory.py
+++ b/synapse/storage/databases/main/directory.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Sequence, Tuple
import attr
@@ -74,7 +74,7 @@ class DirectoryWorkerStore(CacheInvalidationWorkerStore):
)
@cached(max_entries=5000)
- async def get_aliases_for_room(self, room_id: str) -> List[str]:
+ async def get_aliases_for_room(self, room_id: str) -> Sequence[str]:
return await self.db_pool.simple_select_onecol(
"room_aliases",
{"room_id": room_id},
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index c4ac6c33ba..2c2d145666 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -20,7 +20,9 @@ from typing import (
Dict,
Iterable,
List,
+ Mapping,
Optional,
+ Sequence,
Tuple,
Union,
cast,
@@ -260,7 +262,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
for batch in batch_iter(signature_query, 50):
cross_sigs_result = await self.db_pool.runInteraction(
- "get_e2e_cross_signing_signatures",
+ "get_e2e_cross_signing_signatures_for_devices",
self._get_e2e_cross_signing_signatures_for_devices_txn,
batch,
)
@@ -691,7 +693,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
@cached(max_entries=10000)
async def get_e2e_unused_fallback_key_types(
self, user_id: str, device_id: str
- ) -> List[str]:
+ ) -> Sequence[str]:
"""Returns the fallback key types that have an unused key.
Args:
@@ -731,7 +733,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
return user_keys.get(key_type)
@cached(num_args=1)
- def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Dict[str, JsonDict]:
+ def _get_bare_e2e_cross_signing_keys(self, user_id: str) -> Mapping[str, JsonDict]:
"""Dummy function. Only used to make a cache for
_get_bare_e2e_cross_signing_keys_bulk.
"""
@@ -744,7 +746,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
)
async def _get_bare_e2e_cross_signing_keys_bulk(
self, user_ids: Iterable[str]
- ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
+ ) -> Dict[str, Optional[Mapping[str, JsonDict]]]:
"""Returns the cross-signing keys for a set of users. The output of this
function should be passed to _get_e2e_cross_signing_signatures_txn if
the signatures for the calling user need to be fetched.
@@ -765,7 +767,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
)
# The `Optional` comes from the `@cachedList` decorator.
- return cast(Dict[str, Optional[Dict[str, JsonDict]]], result)
+ return cast(Dict[str, Optional[Mapping[str, JsonDict]]], result)
def _get_bare_e2e_cross_signing_keys_bulk_txn(
self,
@@ -924,7 +926,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
@cancellable
async def get_e2e_cross_signing_keys_bulk(
self, user_ids: List[str], from_user_id: Optional[str] = None
- ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
+ ) -> Dict[str, Optional[Mapping[str, JsonDict]]]:
"""Returns the cross-signing keys for a set of users.
Args:
@@ -940,11 +942,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
if from_user_id:
- result = await self.db_pool.runInteraction(
- "get_e2e_cross_signing_signatures",
- self._get_e2e_cross_signing_signatures_txn,
- result,
- from_user_id,
+ result = cast(
+ Dict[str, Optional[Mapping[str, JsonDict]]],
+ await self.db_pool.runInteraction(
+ "get_e2e_cross_signing_signatures",
+ self._get_e2e_cross_signing_signatures_txn,
+ result,
+ from_user_id,
+ ),
)
return result
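The cast wrapped around runInteraction is needed because Dict is invariant in its value type: a Dict[str, Optional[Dict[...]]] cannot be passed where a Dict[str, Optional[Mapping[...]]] is expected, even though every Dict is a Mapping. A toy reproduction of the mypy complaint (names illustrative):

```python
from typing import Dict, Mapping, Optional, cast

def consume(keys: Dict[str, Optional[Mapping[str, int]]]) -> None:
    pass

produced: Dict[str, Optional[Dict[str, int]]] = {"@bob:example.org": {"k": 1}}
# consume(produced)  # rejected: Dict is invariant in its value type
consume(cast(Dict[str, Optional[Mapping[str, int]]], produced))  # accepted, no-op at runtime
```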
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index bbee02ab18..ca780cca36 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -22,6 +22,7 @@ from typing import (
Iterable,
List,
Optional,
+ Sequence,
Set,
Tuple,
cast,
@@ -1004,7 +1005,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
room_id,
)
- async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[Optional[str], int]:
+ async def get_max_depth_of(
+ self, event_ids: Collection[str]
+ ) -> Tuple[Optional[str], int]:
"""Returns the event ID and depth for the event that has the max depth from a set of event IDs
Args:
@@ -1141,7 +1144,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
@cached(max_entries=5000, iterable=True)
- async def get_latest_event_ids_in_room(self, room_id: str) -> List[str]:
+ async def get_latest_event_ids_in_room(self, room_id: str) -> Sequence[str]:
return await self.db_pool.simple_select_onecol(
table="event_forward_extremities",
keyvalues={"room_id": room_id},
@@ -1171,7 +1174,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
@cancellable
async def get_forward_extremities_for_room_at_stream_ordering(
self, room_id: str, stream_ordering: int
- ) -> List[str]:
+ ) -> Sequence[str]:
"""For a given room_id and stream_ordering, return the forward
extremeties of the room at that point in "time".
@@ -1204,7 +1207,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
@cached(max_entries=5000, num_args=2)
async def _get_forward_extremeties_for_room(
self, room_id: str, stream_ordering: int
- ) -> List[str]:
+ ) -> Sequence[str]:
"""For a given room_id and stream_ordering, return the forward
extremeties of the room at that point in "time".
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 3a0c370fde..eeccf5db24 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -203,11 +203,18 @@ class RoomNotifCounts:
# Map of thread ID to the notification counts.
threads: Dict[str, NotifCounts]
+ @staticmethod
+ def empty() -> "RoomNotifCounts":
+ return _EMPTY_ROOM_NOTIF_COUNTS
+
def __len__(self) -> int:
# To properly account for the amount of space in any caches.
return len(self.threads) + 1
+_EMPTY_ROOM_NOTIF_COUNTS = RoomNotifCounts(NotifCounts(), {})
+
+
def _serialize_action(
actions: Collection[Union[Mapping, str]], is_highlight: bool
) -> str:
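RoomNotifCounts.empty() hands out one shared module-level instance rather than allocating a fresh empty object per room, which matters when the value sits in large per-room caches. A standalone sketch of the same pattern (frozen=True is added here as a safety net; the real class relies on callers not mutating the shared value):

```python
import attr

@attr.s(slots=True, frozen=True, auto_attribs=True)
class Counts:
    notify_count: int = 0
    highlight_count: int = 0

@attr.s(slots=True, frozen=True, auto_attribs=True)
class RoomCounts:
    main_timeline: Counts
    threads: dict

    @staticmethod
    def empty() -> "RoomCounts":
        # Return the shared sentinel instead of building a new empty value.
        return _EMPTY_ROOM_COUNTS

_EMPTY_ROOM_COUNTS = RoomCounts(Counts(), {})

assert RoomCounts.empty() is RoomCounts.empty()  # one object, reused everywhere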
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 1536937b67..7996cbb557 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -16,7 +16,6 @@
import itertools
import logging
from collections import OrderedDict
-from http import HTTPStatus
from typing import (
TYPE_CHECKING,
Any,
@@ -26,7 +25,6 @@ from typing import (
Iterable,
List,
Optional,
- Sequence,
Set,
Tuple,
)
@@ -36,7 +34,7 @@ from prometheus_client import Counter
import synapse.metrics
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
-from synapse.api.errors import Codes, SynapseError
+from synapse.api.errors import PartialStateConflictError
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase, relation_from_event
from synapse.events.snapshot import EventContext
@@ -52,7 +50,7 @@ from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import AbstractStreamIdGenerator
from synapse.storage.util.sequence import SequenceGenerator
-from synapse.types import JsonDict, StateMap, get_domain_from_id
+from synapse.types import JsonDict, StateMap, StrCollection, get_domain_from_id
from synapse.util import json_encoder
from synapse.util.iterutils import batch_iter, sorted_topologically
from synapse.util.stringutils import non_null_str_or_none
@@ -72,24 +70,6 @@ event_counter = Counter(
)
-class PartialStateConflictError(SynapseError):
- """An internal error raised when attempting to persist an event with partial state
- after the room containing the event has been un-partial stated.
-
- This error should be handled by recomputing the event context and trying again.
-
- This error has an HTTP status code so that it can be transported over replication.
- It should not be exposed to clients.
- """
-
- def __init__(self) -> None:
- super().__init__(
- HTTPStatus.CONFLICT,
- msg="Cannot persist partial state event in un-partial stated room",
- errcode=Codes.UNKNOWN,
- )
-
-
@attr.s(slots=True, auto_attribs=True)
class DeltaState:
"""Deltas to use to update the `current_state_events` table.
@@ -306,7 +286,7 @@ class PersistEventsStore:
# The set of event_ids to return. This includes all soft-failed events
# and their prev events.
- existing_prevs = set()
+ existing_prevs: Set[str] = set()
def _get_prevs_before_rejected_txn(
txn: LoggingTransaction, batch: Collection[str]
@@ -571,7 +551,7 @@ class PersistEventsStore:
event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
- event_to_auth_chain: Dict[str, Sequence[str]],
+ event_to_auth_chain: Dict[str, StrCollection],
) -> None:
"""Calculate the chain cover index for the given events.
@@ -865,7 +845,7 @@ class PersistEventsStore:
event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
- event_to_auth_chain: Dict[str, Sequence[str]],
+ event_to_auth_chain: Dict[str, StrCollection],
events_to_calc_chain_id_for: Set[str],
chain_map: Dict[str, Tuple[int, int]],
) -> Dict[str, Tuple[int, int]]:
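Replacing Sequence[str] with StrCollection also guards against a classic typing foot-gun, assuming the alias (per synapse.types) is a union of concrete containers that excludes bare str: a str is itself a Collection[str], so it type-checks and then silently iterates character by character:

```python
from typing import Collection

def auth_chain_size(event_ids: Collection[str]) -> int:
    return len(list(event_ids))

# A lone str satisfies Collection[str], so this passes type checking but
# counts characters instead of event IDs:
print(auth_chain_size("$an_event_id"))  # 12, not 1
```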
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index b9d3c36d60..584536111d 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple, cast
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, cast
import attr
@@ -29,7 +29,7 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.events import PersistEventsStore
from synapse.storage.types import Cursor
-from synapse.types import JsonDict
+from synapse.types import JsonDict, StrCollection
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -1061,7 +1061,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
self.event_chain_id_gen, # type: ignore[attr-defined]
event_to_room_id,
event_to_types,
- cast(Dict[str, Sequence[str]], event_to_auth_chain),
+ cast(Dict[str, StrCollection], event_to_auth_chain),
)
return _CalculateChainCover(
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index db9a24db5e..4b1061e6d7 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, cast
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage.database import (
@@ -95,7 +95,7 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
return await self.db_pool.runInteraction("count_users", _count_users)
@cached(num_args=0)
- async def get_monthly_active_count_by_service(self) -> Dict[str, int]:
+ async def get_monthly_active_count_by_service(self) -> Mapping[str, int]:
"""Generates current count of monthly active users broken down by service.
A service is typically an appservice but also includes native matrix users.
Since the `monthly_active_users` table is populated from the `user_ips` table
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 29972d5204..dddf49c2d5 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -21,7 +21,9 @@ from typing import (
Dict,
Iterable,
List,
+ Mapping,
Optional,
+ Sequence,
Tuple,
cast,
)
@@ -288,7 +290,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
async def get_linearized_receipts_for_room(
self, room_id: str, to_key: int, from_key: Optional[int] = None
- ) -> List[dict]:
+ ) -> Sequence[JsonDict]:
"""Get receipts for a single room for sending to clients.
Args:
@@ -311,7 +313,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cached(tree=True)
async def _get_linearized_receipts_for_room(
self, room_id: str, to_key: int, from_key: Optional[int] = None
- ) -> List[JsonDict]:
+ ) -> Sequence[JsonDict]:
"""See get_linearized_receipts_for_room"""
def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
@@ -354,7 +356,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
async def _get_linearized_receipts_for_rooms(
self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
- ) -> Dict[str, List[JsonDict]]:
+ ) -> Dict[str, Sequence[JsonDict]]:
if not room_ids:
return {}
@@ -416,7 +418,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
async def get_linearized_receipts_for_all_rooms(
self, to_key: int, from_key: Optional[int] = None
- ) -> Dict[str, JsonDict]:
+ ) -> Mapping[str, JsonDict]:
"""Get receipts for all rooms between two stream_ids, up
to a limit of the latest 100 read receipts.
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 31f0f2bd3d..9a55e17624 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -16,7 +16,7 @@
import logging
import random
import re
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast
import attr
@@ -192,7 +192,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
)
@cached()
- async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
+ async def get_user_by_id(self, user_id: str) -> Optional[Mapping[str, Any]]:
"""Deprecated: use get_userinfo_by_id instead"""
def get_user_by_id_txn(txn: LoggingTransaction) -> Optional[Dict[str, Any]]:
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 0018d6f7ab..fa3266c081 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -22,6 +22,7 @@ from typing import (
List,
Mapping,
Optional,
+ Sequence,
Set,
Tuple,
Union,
@@ -171,7 +172,7 @@ class RelationsWorkerStore(SQLBaseStore):
direction: Direction = Direction.BACKWARDS,
from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None,
- ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
+ ) -> Tuple[Sequence[_RelatedEvent], Optional[StreamToken]]:
"""Get a list of relations for an event, ordered by topological ordering.
Args:
@@ -397,7 +398,9 @@ class RelationsWorkerStore(SQLBaseStore):
return result is not None
@cached()
- async def get_aggregation_groups_for_event(self, event_id: str) -> List[JsonDict]:
+ async def get_aggregation_groups_for_event(
+ self, event_id: str
+ ) -> Sequence[JsonDict]:
raise NotImplementedError()
@cachedList(
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index ea6a5e2f34..694a5b802c 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -24,6 +24,7 @@ from typing import (
List,
Mapping,
Optional,
+ Sequence,
Set,
Tuple,
Union,
@@ -153,7 +154,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return self._known_servers_count
@cached(max_entries=100000, iterable=True)
- async def get_users_in_room(self, room_id: str) -> List[str]:
+ async def get_users_in_room(self, room_id: str) -> Sequence[str]:
"""Returns a list of users in the room.
Will return inaccurate results for rooms with partial state, since the state for
@@ -190,9 +191,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
@cached()
- def get_user_in_room_with_profile(
- self, room_id: str, user_id: str
- ) -> Dict[str, ProfileInfo]:
+ def get_user_in_room_with_profile(self, room_id: str, user_id: str) -> ProfileInfo:
raise NotImplementedError()
@cachedList(
@@ -246,7 +245,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
@cached(max_entries=100000, iterable=True)
async def get_users_in_room_with_profiles(
self, room_id: str
- ) -> Dict[str, ProfileInfo]:
+ ) -> Mapping[str, ProfileInfo]:
"""Get a mapping from user ID to profile information for all users in a given room.
The profile information comes directly from this room's `m.room.member`
@@ -285,7 +284,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
@cached(max_entries=100000)
- async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]:
+ async def get_room_summary(self, room_id: str) -> Mapping[str, MemberSummary]:
"""Get the details of a room roughly suitable for use by the room
summary extension to /sync. Useful when lazy loading room members.
Args:
@@ -357,7 +356,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
@cached()
async def get_invited_rooms_for_local_user(
self, user_id: str
- ) -> List[RoomsForUser]:
+ ) -> Sequence[RoomsForUser]:
"""Get all the rooms the *local* user is invited to.
Args:
@@ -475,7 +474,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return results
@cached(iterable=True)
- async def get_local_users_in_room(self, room_id: str) -> List[str]:
+ async def get_local_users_in_room(self, room_id: str) -> Sequence[str]:
"""
Retrieves a list of the current roommembers who are local to the server.
"""
@@ -791,7 +790,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"""Returns the set of users who share a room with `user_id`"""
room_ids = await self.get_rooms_for_user(user_id)
- user_who_share_room = set()
+ user_who_share_room: Set[str] = set()
for room_id in room_ids:
user_ids = await self.get_users_in_room(room_id)
user_who_share_room.update(user_ids)
@@ -953,7 +952,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return True
@cached(iterable=True, max_entries=10000)
- async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
+ async def get_current_hosts_in_room(self, room_id: str) -> AbstractSet[str]:
"""Get current hosts in room based on current state."""
# First we check if we already have `get_users_in_room` in the cache, as
diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py
index 05da15074a..5dcb1fc0b5 100644
--- a/synapse/storage/databases/main/signatures.py
+++ b/synapse/storage/databases/main/signatures.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Collection, Dict, List, Tuple
+from typing import Collection, Dict, List, Mapping, Tuple
from unpaddedbase64 import encode_base64
@@ -26,7 +26,7 @@ from synapse.util.caches.descriptors import cached, cachedList
class SignatureWorkerStore(EventsWorkerStore):
@cached()
- def get_event_reference_hash(self, event_id: str) -> Dict[str, Dict[str, bytes]]:
+ def get_event_reference_hash(self, event_id: str) -> Mapping[str, bytes]:
# This is a dummy function to allow get_event_reference_hashes
# to use its cache
raise NotImplementedError()
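get_event_reference_hash is a deliberate stub: the @cachedList-decorated batch method below it populates the per-key cache the stub declares. A plain-Python sketch of that data flow, stripped of the real decorators' eviction, metrics and invalidation; _fetch_from_db is a hypothetical stand-in for the actual table query:

```python
from typing import Dict, Iterable, Mapping

class SignatureStoreSketch:
    def __init__(self) -> None:
        # The per-event cache that @cached() would declare via the stub.
        self._hash_cache: Dict[str, Mapping[str, bytes]] = {}

    def get_event_reference_hash(self, event_id: str) -> Mapping[str, bytes]:
        # Stub in the real store; only its cache is ever used.
        return self._hash_cache[event_id]

    def get_event_reference_hashes(
        self, event_ids: Iterable[str]
    ) -> Dict[str, Mapping[str, bytes]]:
        wanted = list(event_ids)
        missing = [e for e in wanted if e not in self._hash_cache]
        if missing:
            # One batched lookup for all misses, not one query per event.
            self._hash_cache.update(self._fetch_from_db(missing))
        return {e: self._hash_cache[e] for e in wanted}

    def _fetch_from_db(self, event_ids: Iterable[str]) -> Dict[str, Mapping[str, bytes]]:
        # Hypothetical stand-in for the real database query.
        return {e: {"sha256": b"..."} for e in event_ids}
```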
@@ -36,7 +36,7 @@ class SignatureWorkerStore(EventsWorkerStore):
)
async def get_event_reference_hashes(
self, event_ids: Collection[str]
- ) -> Dict[str, Dict[str, bytes]]:
+ ) -> Mapping[str, Mapping[str, bytes]]:
"""Get all hashes for given events.
Args:
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index d5500cdd47..c149a9eacb 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -15,7 +15,7 @@
# limitations under the License.
import logging
-from typing import Any, Dict, Iterable, List, Tuple, cast
+from typing import Any, Dict, Iterable, List, Mapping, Tuple, cast
from synapse.api.constants import AccountDataTypes
from synapse.replication.tcp.streams import AccountDataStream
@@ -32,7 +32,9 @@ logger = logging.getLogger(__name__)
class TagsWorkerStore(AccountDataWorkerStore):
@cached()
- async def get_tags_for_user(self, user_id: str) -> Dict[str, Dict[str, JsonDict]]:
+ async def get_tags_for_user(
+ self, user_id: str
+ ) -> Mapping[str, Mapping[str, JsonDict]]:
"""Get all the tags for a user.
@@ -107,7 +109,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
async def get_updated_tags(
self, user_id: str, stream_id: int
- ) -> Dict[str, Dict[str, JsonDict]]:
+ ) -> Mapping[str, Mapping[str, JsonDict]]:
"""Get all the tags for the rooms where the tags have changed since the
given version
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 14ef5b040d..f6a6fd4079 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -16,9 +16,9 @@ import logging
import re
from typing import (
TYPE_CHECKING,
- Dict,
Iterable,
List,
+ Mapping,
Optional,
Sequence,
Set,
@@ -586,7 +586,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
@cached()
- async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, str]]:
+ async def get_user_in_directory(self, user_id: str) -> Optional[Mapping[str, str]]:
return await self.db_pool.simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 3acdb39da7..6c335a9315 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -23,7 +23,7 @@ from typing_extensions import Counter as CounterType
from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import LoggingDatabaseConnection
-from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.schema import SCHEMA_COMPAT_VERSION, SCHEMA_VERSION
from synapse.storage.types import Cursor
@@ -108,9 +108,14 @@ def prepare_database(
# so we start one before running anything. This ensures that any upgrades
# are either applied completely, or not at all.
#
- # (psycopg2 automatically starts a transaction as soon as we run any statements
- # at all, so this is redundant but harmless there.)
- cur.execute("BEGIN TRANSACTION")
+ # psycopg2 does not automatically start transactions when in autocommit mode.
+ # While it is technically harmless to nest transactions in Postgres, doing so
+ # results in a warning in Postgres' logs per query, which we'd rather avoid.
+ if isinstance(database_engine, Sqlite3Engine) or (
+ isinstance(database_engine, PostgresEngine) and db_conn.autocommit
+ ):
+ cur.execute("BEGIN TRANSACTION")
logger.info("%r: Checking existing schema version", databases)
version_info = _get_or_create_schema_state(cur, database_engine)
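The guard now issues an explicit BEGIN only when nothing has opened a transaction implicitly (SQLite, or psycopg2 in autocommit mode), keeping schema upgrades all-or-nothing without triggering Postgres' nested-BEGIN log warning. A runnable sqlite3 analogue of that all-or-nothing behaviour; isolation_level=None puts the driver in autocommit, much like psycopg2's autocommit flag:

```python
import sqlite3

conn = sqlite3.connect(":memory:", isolation_level=None)  # autocommit mode
cur = conn.cursor()
cur.execute("BEGIN TRANSACTION")  # nothing has started a transaction for us
cur.execute("CREATE TABLE schema_version (version INTEGER)")
cur.execute("INSERT INTO schema_version VALUES (73)")
cur.execute("ROLLBACK")  # an upgrade failure undoes every statement in the block
cur.execute("SELECT name FROM sqlite_master WHERE name = 'schema_version'")
assert cur.fetchone() is None  # the half-applied upgrade left no trace
```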
diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index f82d1cfc29..52e366c8ae 100644
--- a/synapse/types/__init__.py
+++ b/synapse/types/__init__.py
@@ -69,6 +69,8 @@ StateMap = Mapping[StateKey, T]
MutableStateMap = MutableMapping[StateKey, T]
# JSON types. These could be made stronger, but will do for now.
+# A "simple" (canonical) JSON value.
+SimpleJsonValue = Optional[Union[str, int, bool]]
# A JSON-serialisable dict.
JsonDict = Dict[str, Any]
# A JSON-serialisable mapping; roughly speaking an immutable JSONDict.
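A small usage sketch for the new alias (the render function is illustrative, not part of the change): because SimpleJsonValue names exactly the scalar JSON values in play, a type checker can enforce exhaustive handling over it.

```python
from typing import Optional, Union

SimpleJsonValue = Optional[Union[str, int, bool]]

def render(value: SimpleJsonValue) -> str:
    if value is None:
        return "null"
    if isinstance(value, bool):  # test bool before int: bool subclasses int
        return "true" if value else "false"
    if isinstance(value, int):
        return str(value)
    return '"%s"' % value  # str case (escaping omitted for brevity)
```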