diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index da129ec16a..d283a938c0 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -16,6 +16,7 @@
# limitations under the License.
import logging
import random
+from builtins import dict
from http import HTTPStatus
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple
@@ -577,7 +578,7 @@ class EventCreationHandler:
state_map: Optional[StateMap[str]] = None,
for_batch: bool = False,
current_state_group: Optional[int] = None,
- ) -> Tuple[EventBase, UnpersistedEventContextBase]:
+ ) -> Tuple[EventBase, UnpersistedEventContextBase, Optional[dict]]:
"""
Given a dict from a client, create a new event. If bool for_batch is true, will
create an event using the prev_event_ids, and will create an event context for
@@ -649,7 +650,9 @@ class EventCreationHandler:
exceeded
Returns:
- Tuple of created event, Context
+ Tuple of created event, Context, and an optional event dict to form the basis
+ of a new event if third_party_rules would like to send an additional event as a
+ consequence of this event.
"""
await self.auth_blocking.check_auth_blocking(requester=requester)
@@ -711,7 +714,7 @@ class EventCreationHandler:
builder.internal_metadata.historical = historical
- event, unpersisted_context = await self.create_new_client_event(
+ event, unpersisted_context, new_event = await self.create_new_client_event(
builder=builder,
requester=requester,
allow_no_prev_events=allow_no_prev_events,
@@ -765,7 +768,7 @@ class EventCreationHandler:
)
self.validator.validate_new(event, self.config)
- return event, unpersisted_context
+ return event, unpersisted_context, new_event
async def _is_exempt_from_privacy_policy(
self, builder: EventBuilder, requester: Requester
@@ -1005,7 +1008,11 @@ class EventCreationHandler:
max_retries = 5
for i in range(max_retries):
try:
- event, unpersisted_context = await self.create_event(
+ (
+ event,
+ unpersisted_context,
+ third_party_event_dict,
+ ) = await self.create_event(
requester,
event_dict,
txn_id=txn_id,
@@ -1054,9 +1061,24 @@ class EventCreationHandler:
Codes.FORBIDDEN,
)
+ events_and_context = [(event, context)]
+ if third_party_event_dict:
+ (
+ third_party_event,
+ unpersisted_third_party_context,
+ _,
+ ) = await self.create_event(
+ requester,
+ third_party_event_dict,
+ )
+ third_party_context = await unpersisted_third_party_context.persist(
+ third_party_event
+ )
+ events_and_context.append((third_party_event, third_party_context))
+
ev = await self.handle_new_client_event(
requester=requester,
- events_and_context=[(event, context)],
+ events_and_context=events_and_context,
ratelimit=ratelimit,
ignore_shadow_ban=ignore_shadow_ban,
)
@@ -1086,7 +1108,7 @@ class EventCreationHandler:
state_map: Optional[StateMap[str]] = None,
for_batch: bool = False,
current_state_group: Optional[int] = None,
- ) -> Tuple[EventBase, UnpersistedEventContextBase]:
+ ) -> Tuple[EventBase, UnpersistedEventContextBase, Optional[dict]]:
"""Create a new event for a local client. If bool for_batch is true, will
create an event using the prev_event_ids, and will create an event context for
the event using the parameters state_map and current_state_group, thus these parameters
@@ -1135,7 +1157,9 @@ class EventCreationHandler:
batch persisting
Returns:
- Tuple of created event, UnpersistedEventContext
+ Tuple of created event, UnpersistedEventContext, and an optional event dict
+ to form the basis of a new event if third_party_rules would like to send an
+ additional event as a consequence of this event.
"""
# Strip down the state_event_ids to only what we need to auth the event.
# For example, we don't need extra m.room.member that don't match event.sender
@@ -1269,9 +1293,11 @@ class EventCreationHandler:
if requester:
context.app_service = requester.app_service
- res, new_content = await self.third_party_event_rules.check_event_allowed(
- event, context
- )
+ (
+ res,
+ new_content,
+ new_event,
+ ) = await self.third_party_event_rules.check_event_allowed(event, context)
if res is False:
logger.info(
"Event %s forbidden by third-party rules",
@@ -1291,7 +1317,7 @@ class EventCreationHandler:
await self._validate_event_relation(event)
logger.debug("Created event %s", event.event_id)
- return event, context
+ return event, context, new_event
async def _validate_event_relation(self, event: EventBase) -> None:
"""
@@ -2046,7 +2072,7 @@ class EventCreationHandler:
max_retries = 5
for i in range(max_retries):
try:
- event, unpersisted_context = await self.create_event(
+ event, unpersisted_context, _ = await self.create_event(
requester,
{
"type": EventTypes.Dummy,
|