diff options
Diffstat (limited to 'synapse/events')
-rw-r--r-- | synapse/events/__init__.py | 45 | ||||
-rw-r--r-- | synapse/events/snapshot.py | 180 | ||||
-rw-r--r-- | synapse/events/spamcheck.py | 81 |
3 files changed, 137 insertions, 169 deletions
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index c238376caf..39ad2793d9 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -15,6 +15,7 @@ # limitations under the License. import abc +import collections.abc import os from typing import ( TYPE_CHECKING, @@ -32,9 +33,11 @@ from typing import ( overload, ) +import attr from typing_extensions import Literal from unpaddedbase64 import encode_base64 +from synapse.api.constants import RelationTypes from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions from synapse.types import JsonDict, RoomStreamToken from synapse.util.caches import intern_dict @@ -615,3 +618,45 @@ def make_event_from_dict( return event_type( event_dict, room_version, internal_metadata_dict or {}, rejected_reason ) + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _EventRelation: + # The target event of the relation. + parent_id: str + # The relation type. + rel_type: str + # The aggregation key. Will be None if the rel_type is not m.annotation or is + # not a string. + aggregation_key: Optional[str] + + +def relation_from_event(event: EventBase) -> Optional[_EventRelation]: + """ + Attempt to parse relation information an event. + + Returns: + The event relation information, if it is valid. None, otherwise. + """ + relation = event.content.get("m.relates_to") + if not relation or not isinstance(relation, collections.abc.Mapping): + # No relation information. + return None + + # Relations must have a type and parent event ID. + rel_type = relation.get("rel_type") + if not isinstance(rel_type, str): + return None + + parent_id = relation.get("event_id") + if not isinstance(parent_id, str): + return None + + # Annotations have a key field. + aggregation_key = None + if rel_type == RelationTypes.ANNOTATION: + aggregation_key = relation.get("key") + if not isinstance(aggregation_key, str): + aggregation_key = None + + return _EventRelation(parent_id, rel_type, aggregation_key) diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 46042b2bf7..9ccd24b298 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -15,12 +15,10 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union import attr from frozendict import frozendict - -from twisted.internet.defer import Deferred +from typing_extensions import Literal from synapse.appservice import ApplicationService from synapse.events import EventBase -from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.types import JsonDict, StateMap if TYPE_CHECKING: @@ -60,6 +58,9 @@ class EventContext: If ``state_group`` is None (ie, the event is an outlier), ``state_group_before_event`` will always also be ``None``. + state_delta_due_to_event: If `state_group` and `state_group_before_event` are not None + then this is the delta of the state between the two groups. + prev_group: If it is known, ``state_group``'s prev_group. Note that this being None does not necessarily mean that ``state_group`` does not have a prev_group! @@ -78,73 +79,47 @@ class EventContext: app_service: If this event is being sent by a (local) application service, that app service. - _current_state_ids: The room state map, including this event - ie, the state - in ``state_group``. - - (type, state_key) -> event_id - - For an outlier, this is {} - - Note that this is a private attribute: it should be accessed via - ``get_current_state_ids``. _AsyncEventContext impl calculates this - on-demand: it will be None until that happens. - - _prev_state_ids: The room state map, excluding this event - ie, the state - in ``state_group_before_event``. For a non-state - event, this will be the same as _current_state_events. - - Note that it is a completely different thing to prev_group! - - (type, state_key) -> event_id - - For an outlier, this is {} - - As with _current_state_ids, this is a private attribute. It should be - accessed via get_prev_state_ids. - partial_state: if True, we may be storing this event with a temporary, incomplete state. """ - rejected: Union[bool, str] = False + _storage: "Storage" + rejected: Union[Literal[False], str] = False _state_group: Optional[int] = None state_group_before_event: Optional[int] = None + _state_delta_due_to_event: Optional[StateMap[str]] = None prev_group: Optional[int] = None delta_ids: Optional[StateMap[str]] = None app_service: Optional[ApplicationService] = None - _current_state_ids: Optional[StateMap[str]] = None - _prev_state_ids: Optional[StateMap[str]] = None - partial_state: bool = False @staticmethod def with_state( + storage: "Storage", state_group: Optional[int], state_group_before_event: Optional[int], - current_state_ids: Optional[StateMap[str]], - prev_state_ids: Optional[StateMap[str]], + state_delta_due_to_event: Optional[StateMap[str]], partial_state: bool, prev_group: Optional[int] = None, delta_ids: Optional[StateMap[str]] = None, ) -> "EventContext": return EventContext( - current_state_ids=current_state_ids, - prev_state_ids=prev_state_ids, + storage=storage, state_group=state_group, state_group_before_event=state_group_before_event, + state_delta_due_to_event=state_delta_due_to_event, prev_group=prev_group, delta_ids=delta_ids, partial_state=partial_state, ) @staticmethod - def for_outlier() -> "EventContext": + def for_outlier( + storage: "Storage", + ) -> "EventContext": """Return an EventContext instance suitable for persisting an outlier event""" - return EventContext( - current_state_ids={}, - prev_state_ids={}, - ) + return EventContext(storage=storage) async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict: """Converts self to a type that can be serialized as JSON, and then @@ -157,24 +132,14 @@ class EventContext: The serialized event. """ - # We don't serialize the full state dicts, instead they get pulled out - # of the DB on the other side. However, the other side can't figure out - # the prev_state_ids, so if we're a state event we include the event - # id that we replaced in the state. - if event.is_state(): - prev_state_ids = await self.get_prev_state_ids() - prev_state_id = prev_state_ids.get((event.type, event.state_key)) - else: - prev_state_id = None - return { - "prev_state_id": prev_state_id, - "event_type": event.type, - "event_state_key": event.get_state_key(), "state_group": self._state_group, "state_group_before_event": self.state_group_before_event, "rejected": self.rejected, "prev_group": self.prev_group, + "state_delta_due_to_event": _encode_state_dict( + self._state_delta_due_to_event + ), "delta_ids": _encode_state_dict(self.delta_ids), "app_service_id": self.app_service.id if self.app_service else None, "partial_state": self.partial_state, @@ -192,16 +157,16 @@ class EventContext: Returns: The event context. """ - context = _AsyncEventContextImpl( + context = EventContext( # We use the state_group and prev_state_id stuff to pull the # current_state_ids out of the DB and construct prev_state_ids. storage=storage, - prev_state_id=input["prev_state_id"], - event_type=input["event_type"], - event_state_key=input["event_state_key"], state_group=input["state_group"], state_group_before_event=input["state_group_before_event"], prev_group=input["prev_group"], + state_delta_due_to_event=_decode_state_dict( + input["state_delta_due_to_event"] + ), delta_ids=_decode_state_dict(input["delta_ids"]), rejected=input["rejected"], partial_state=input.get("partial_state", False), @@ -249,8 +214,15 @@ class EventContext: if self.rejected: raise RuntimeError("Attempt to access state_ids of rejected event") - await self._ensure_fetched() - return self._current_state_ids + assert self._state_delta_due_to_event is not None + + prev_state_ids = await self.get_prev_state_ids() + + if self._state_delta_due_to_event: + prev_state_ids = dict(prev_state_ids) + prev_state_ids.update(self._state_delta_due_to_event) + + return prev_state_ids async def get_prev_state_ids(self) -> StateMap[str]: """ @@ -265,94 +237,10 @@ class EventContext: Maps a (type, state_key) to the event ID of the state event matching this tuple. """ - await self._ensure_fetched() - # There *should* be previous state IDs now. - assert self._prev_state_ids is not None - return self._prev_state_ids - - def get_cached_current_state_ids(self) -> Optional[StateMap[str]]: - """Gets the current state IDs if we have them already cached. - - It is an error to access this for a rejected event, since rejected state should - not make it into the room state. This method will raise an exception if - ``rejected`` is set. - - Returns: - Returns None if we haven't cached the state or if state_group is None - (which happens when the associated event is an outlier). - - Otherwise, returns the the current state IDs. - """ - if self.rejected: - raise RuntimeError("Attempt to access state_ids of rejected event") - - return self._current_state_ids - - async def _ensure_fetched(self) -> None: - return None - - -@attr.s(slots=True) -class _AsyncEventContextImpl(EventContext): - """ - An implementation of EventContext which fetches _current_state_ids and - _prev_state_ids from the database on demand. - - Attributes: - - _storage - - _fetching_state_deferred: Resolves when *_state_ids have been calculated. - None if we haven't started calculating yet - - _event_type: The type of the event the context is associated with. - - _event_state_key: The state_key of the event the context is associated with. - - _prev_state_id: If the event associated with the context is a state event, - then `_prev_state_id` is the event_id of the state that was replaced. - """ - - # This needs to have a default as we're inheriting - _storage: "Storage" = attr.ib(default=None) - _prev_state_id: Optional[str] = attr.ib(default=None) - _event_type: str = attr.ib(default=None) - _event_state_key: Optional[str] = attr.ib(default=None) - _fetching_state_deferred: Optional["Deferred[None]"] = attr.ib(default=None) - - async def _ensure_fetched(self) -> None: - if not self._fetching_state_deferred: - self._fetching_state_deferred = run_in_background(self._fill_out_state) - - await make_deferred_yieldable(self._fetching_state_deferred) - - async def _fill_out_state(self) -> None: - """Called to populate the _current_state_ids and _prev_state_ids - attributes by loading from the database. - """ - if self.state_group is None: - # No state group means the event is an outlier. Usually the state_ids dicts are also - # pre-set to empty dicts, but they get reset when the context is serialized, so set - # them to empty dicts again here. - self._current_state_ids = {} - self._prev_state_ids = {} - return - - current_state_ids = await self._storage.state.get_state_ids_for_group( - self.state_group + assert self.state_group_before_event is not None + return await self._storage.state.get_state_ids_for_group( + self.state_group_before_event ) - # Set this separately so mypy knows current_state_ids is not None. - self._current_state_ids = current_state_ids - if self._event_state_key is not None: - self._prev_state_ids = dict(current_state_ids) - - key = (self._event_type, self._event_state_key) - if self._prev_state_id: - self._prev_state_ids[key] = self._prev_state_id - else: - self._prev_state_ids.pop(key, None) - else: - self._prev_state_ids = current_state_ids def _encode_state_dict( diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 3b6795d40f..f30207376a 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -32,6 +32,7 @@ from synapse.rest.media.v1.media_storage import ReadableFileWrapper from synapse.spam_checker_api import RegistrationBehaviour from synapse.types import RoomAlias, UserProfile from synapse.util.async_helpers import delay_cancellation, maybe_awaitable +from synapse.util.metrics import Measure if TYPE_CHECKING: import synapse.events @@ -162,7 +163,10 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer") -> None: class SpamChecker: - def __init__(self) -> None: + def __init__(self, hs: "synapse.server.HomeServer") -> None: + self.hs = hs + self.clock = hs.get_clock() + self._check_event_for_spam_callbacks: List[CHECK_EVENT_FOR_SPAM_CALLBACK] = [] self._user_may_join_room_callbacks: List[USER_MAY_JOIN_ROOM_CALLBACK] = [] self._user_may_invite_callbacks: List[USER_MAY_INVITE_CALLBACK] = [] @@ -255,7 +259,10 @@ class SpamChecker: will be used as the error message returned to the user. """ for callback in self._check_event_for_spam_callbacks: - res: Union[bool, str] = await delay_cancellation(callback(event)) + with Measure( + self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) + ): + res: Union[bool, str] = await delay_cancellation(callback(event)) if res: return res @@ -276,9 +283,12 @@ class SpamChecker: Whether the user may join the room """ for callback in self._user_may_join_room_callbacks: - may_join_room = await delay_cancellation( - callback(user_id, room_id, is_invited) - ) + with Measure( + self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) + ): + may_join_room = await delay_cancellation( + callback(user_id, room_id, is_invited) + ) if may_join_room is False: return False @@ -300,9 +310,12 @@ class SpamChecker: True if the user may send an invite, otherwise False """ for callback in self._user_may_invite_callbacks: - may_invite = await delay_cancellation( - callback(inviter_userid, invitee_userid, room_id) - ) + with Measure( + self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) + ): + may_invite = await delay_cancellation( + callback(inviter_userid, invitee_userid, room_id) + ) if may_invite is False: return False @@ -328,9 +341,12 @@ class SpamChecker: True if the user may send the invite, otherwise False """ for callback in self._user_may_send_3pid_invite_callbacks: - may_send_3pid_invite = await delay_cancellation( - callback(inviter_userid, medium, address, room_id) - ) + with Measure( + self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) + ): + may_send_3pid_invite = await delay_cancellation( + callback(inviter_userid, medium, address, room_id) + ) if may_send_3pid_invite is False: return False @@ -348,7 +364,10 @@ class SpamChecker: True if the user may create a room, otherwise False """ for callback in self._user_may_create_room_callbacks: - may_create_room = await delay_cancellation(callback(userid)) + with Measure( + self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) + ): + may_create_room = await delay_cancellation(callback(userid)) if may_create_room is False: return False @@ -369,9 +388,12 @@ class SpamChecker: True if the user may create a room alias, otherwise False """ for callback in self._user_may_create_room_alias_callbacks: - may_create_room_alias = await delay_cancellation( - callback(userid, room_alias) - ) + with Measure( + self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) + ): + may_create_room_alias = await delay_cancellation( + callback(userid, room_alias) + ) if may_create_room_alias is False: return False @@ -390,7 +412,10 @@ class SpamChecker: True if the user may publish the room, otherwise False """ for callback in self._user_may_publish_room_callbacks: - may_publish_room = await delay_cancellation(callback(userid, room_id)) + with Measure( + self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) + ): + may_publish_room = await delay_cancellation(callback(userid, room_id)) if may_publish_room is False: return False @@ -412,9 +437,13 @@ class SpamChecker: True if the user is spammy. """ for callback in self._check_username_for_spam_callbacks: - # Make a copy of the user profile object to ensure the spam checker cannot - # modify it. - if await delay_cancellation(callback(user_profile.copy())): + with Measure( + self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) + ): + # Make a copy of the user profile object to ensure the spam checker cannot + # modify it. + res = await delay_cancellation(callback(user_profile.copy())) + if res: return True return False @@ -442,9 +471,12 @@ class SpamChecker: """ for callback in self._check_registration_for_spam_callbacks: - behaviour = await delay_cancellation( - callback(email_threepid, username, request_info, auth_provider_id) - ) + with Measure( + self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) + ): + behaviour = await delay_cancellation( + callback(email_threepid, username, request_info, auth_provider_id) + ) assert isinstance(behaviour, RegistrationBehaviour) if behaviour != RegistrationBehaviour.ALLOW: return behaviour @@ -486,7 +518,10 @@ class SpamChecker: """ for callback in self._check_media_file_for_spam_callbacks: - spam = await delay_cancellation(callback(file_wrapper, file_info)) + with Measure( + self.clock, "{}.{}".format(callback.__module__, callback.__qualname__) + ): + spam = await delay_cancellation(callback(file_wrapper, file_info)) if spam: return True |