summary refs log tree commit diff
path: root/synapse/handlers/message.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/message.py')
-rw-r--r--synapse/handlers/message.py52
1 files changed, 39 insertions, 13 deletions
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py

index da129ec16a..d283a938c0 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py
@@ -16,6 +16,7 @@ # limitations under the License. import logging import random +from builtins import dict from http import HTTPStatus from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple @@ -577,7 +578,7 @@ class EventCreationHandler: state_map: Optional[StateMap[str]] = None, for_batch: bool = False, current_state_group: Optional[int] = None, - ) -> Tuple[EventBase, UnpersistedEventContextBase]: + ) -> Tuple[EventBase, UnpersistedEventContextBase, Optional[dict]]: """ Given a dict from a client, create a new event. If bool for_batch is true, will create an event using the prev_event_ids, and will create an event context for @@ -649,7 +650,9 @@ class EventCreationHandler: exceeded Returns: - Tuple of created event, Context + Tuple of created event, Context, and an optional event dict to form the basis + of a new event if third_party_rules would like to send an additional event as a + consequence of this event. """ await self.auth_blocking.check_auth_blocking(requester=requester) @@ -711,7 +714,7 @@ class EventCreationHandler: builder.internal_metadata.historical = historical - event, unpersisted_context = await self.create_new_client_event( + event, unpersisted_context, new_event = await self.create_new_client_event( builder=builder, requester=requester, allow_no_prev_events=allow_no_prev_events, @@ -765,7 +768,7 @@ class EventCreationHandler: ) self.validator.validate_new(event, self.config) - return event, unpersisted_context + return event, unpersisted_context, new_event async def _is_exempt_from_privacy_policy( self, builder: EventBuilder, requester: Requester @@ -1005,7 +1008,11 @@ class EventCreationHandler: max_retries = 5 for i in range(max_retries): try: - event, unpersisted_context = await self.create_event( + ( + event, + unpersisted_context, + third_party_event_dict, + ) = await self.create_event( requester, event_dict, txn_id=txn_id, @@ -1054,9 +1061,24 @@ class EventCreationHandler: Codes.FORBIDDEN, ) + events_and_context = [(event, context)] + if third_party_event_dict: + ( + third_party_event, + unpersisted_third_party_context, + _, + ) = await self.create_event( + requester, + third_party_event_dict, + ) + third_party_context = await unpersisted_third_party_context.persist( + third_party_event + ) + events_and_context.append((third_party_event, third_party_context)) + ev = await self.handle_new_client_event( requester=requester, - events_and_context=[(event, context)], + events_and_context=events_and_context, ratelimit=ratelimit, ignore_shadow_ban=ignore_shadow_ban, ) @@ -1086,7 +1108,7 @@ class EventCreationHandler: state_map: Optional[StateMap[str]] = None, for_batch: bool = False, current_state_group: Optional[int] = None, - ) -> Tuple[EventBase, UnpersistedEventContextBase]: + ) -> Tuple[EventBase, UnpersistedEventContextBase, Optional[dict]]: """Create a new event for a local client. If bool for_batch is true, will create an event using the prev_event_ids, and will create an event context for the event using the parameters state_map and current_state_group, thus these parameters @@ -1135,7 +1157,9 @@ class EventCreationHandler: batch persisting Returns: - Tuple of created event, UnpersistedEventContext + Tuple of created event, UnpersistedEventContext, and an optional event dict + to form the basis of a new event if third_party_rules would like to send an + additional event as a consequence of this event. """ # Strip down the state_event_ids to only what we need to auth the event. # For example, we don't need extra m.room.member that don't match event.sender @@ -1269,9 +1293,11 @@ class EventCreationHandler: if requester: context.app_service = requester.app_service - res, new_content = await self.third_party_event_rules.check_event_allowed( - event, context - ) + ( + res, + new_content, + new_event, + ) = await self.third_party_event_rules.check_event_allowed(event, context) if res is False: logger.info( "Event %s forbidden by third-party rules", @@ -1291,7 +1317,7 @@ class EventCreationHandler: await self._validate_event_relation(event) logger.debug("Created event %s", event.event_id) - return event, context + return event, context, new_event async def _validate_event_relation(self, event: EventBase) -> None: """ @@ -2046,7 +2072,7 @@ class EventCreationHandler: max_retries = 5 for i in range(max_retries): try: - event, unpersisted_context = await self.create_event( + event, unpersisted_context, _ = await self.create_event( requester, { "type": EventTypes.Dummy,