diff --git a/changelog.d/15131.misc b/changelog.d/15131.misc
new file mode 100644
index 0000000000..441e77ba65
--- /dev/null
+++ b/changelog.d/15131.misc
@@ -0,0 +1 @@
+Add a new third party callback `check_event_allowed_v2` that is compatible with new batch persisting mechanisms.
\ No newline at end of file
diff --git a/docs/modules/third_party_rules_callbacks.md b/docs/modules/third_party_rules_callbacks.md
index 4a27d976fb..f301bfdcba 100644
--- a/docs/modules/third_party_rules_callbacks.md
+++ b/docs/modules/third_party_rules_callbacks.md
@@ -10,6 +10,75 @@ The available third party rules callbacks are:
### `check_event_allowed`
+_First introduced in Synapse v1.7x.x
+
+```python
+async def check_event_allowed_v2(
+ event: "synapse.events.EventBase",
+ state_events: "synapse.types.StateMap",
+) -> Tuple[bool, Optional[dict], Optional[dict]]
+```
+
+**<span style="color:red">
+This callback is very experimental and can and will break without notice. Module developers
+are encouraged to implement `check_event_for_spam` from the spam checker category instead.
+</span>**
+
+Returns:
+
+- A tuple consisting of:
+
+ - a boolean representing whether or not the event is allowed
+ - an optional dict to form the basis of a replacement event for the event
+ - an optional dict to form the basis of an additional event to be sent into the
+ room
+
+Called when processing any incoming event, with the event and a `StateMap`
+representing the current state of the room the event is being sent into. A `StateMap` is
+a dictionary that maps tuples containing an event type and a state key to the
+corresponding state event. For example retrieving the room's `m.room.create` event from
+the `state_events` argument would look like this: `state_events.get(("m.room.create", ""))`.
+The module must return a boolean indicating whether the event can be allowed.
+
+Note that this callback function processes incoming events coming via federation
+traffic (on top of client traffic). This means denying an event might cause the local
+copy of the room's history to diverge from that of remote servers. This may cause
+federation issues in the room. It is strongly recommended to only deny events using this
+callback function if the sender is a local user, or in a private federation in which all
+servers are using the same module, with the same configuration.
+
+If the boolean returned by the module is `True`, it may tell Synapse to replace the
+event with new data by returning the new event's data as a dictionary. In order to do
+that, it is recommended the module calls `event.get_dict()` to get the current event as a
+dictionary, and modify the returned dictionary accordingly.
+
+Module writers may also wish to use this check to send a second event into the room along
+with the event being checked, if this is the case the module writer must provide a dict that
+will form the basis of the event that is to be added to the room and it must be returned by `check_event_allowed_v2`.
+This dict will then be turned into an event at the appropriate time and it will be persisted after the event
+that triggered it, and if the event that triggered it is in a batch of events for persisting, it will be added to the
+end of that batch. Note that the event MAY NOT be a membership event.
+
+If `check_event_allowed_v2` raises an exception, the module is assumed to have failed.
+The event will not be accepted but is not treated as explicitly rejected, either.
+An HTTP request causing the module check will likely result in a 500 Internal
+Server Error.
+
+When the boolean returned by the module is `False`, the event is rejected.
+(Module developers should not use exceptions for rejection.)
+
+Note that replacing the event or adding an event only works for events sent by local users, not for events
+received over federation.
+
+If multiple modules implement this callback, they will be considered in order. If a
+callback returns `True`, Synapse falls through to the next one. The value of the first
+callback that does not return `True` will be used. If this happens, Synapse will not call
+any of the subsequent implementations of this callback. This callback cannot be used in conjunction with `check_event_allowed`,
+only one of these callbacks may be operational at a time - if both `check_event_allowed` and `check_event_allowed_v2`
+active only `check_event_allowed` will be executed.
+
+### `check_event_allowed`
+
_First introduced in Synapse v1.39.0_
```python
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 61d4530be7..79e2c994d6 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -32,6 +32,10 @@ logger = logging.getLogger(__name__)
CHECK_EVENT_ALLOWED_CALLBACK = Callable[
[EventBase, StateMap[EventBase]], Awaitable[Tuple[bool, Optional[dict]]]
]
+CHECK_EVENT_ALLOWED_V2_CALLBACK = Callable[
+ [EventBase, StateMap[EventBase]],
+ Awaitable[Tuple[bool, Optional[dict], Optional[dict]]],
+]
ON_CREATE_ROOM_CALLBACK = Callable[[Requester, dict, bool], Awaitable]
CHECK_THREEPID_CAN_BE_INVITED_CALLBACK = Callable[
[str, str, StateMap[EventBase]], Awaitable[bool]
@@ -155,6 +159,9 @@ class ThirdPartyEventRules:
self._storage_controllers = hs.get_storage_controllers()
self._check_event_allowed_callbacks: List[CHECK_EVENT_ALLOWED_CALLBACK] = []
+ self._check_event_allowed_v2_callbacks: List[
+ CHECK_EVENT_ALLOWED_V2_CALLBACK
+ ] = []
self._on_create_room_callbacks: List[ON_CREATE_ROOM_CALLBACK] = []
self._check_threepid_can_be_invited_callbacks: List[
CHECK_THREEPID_CAN_BE_INVITED_CALLBACK
@@ -184,6 +191,7 @@ class ThirdPartyEventRules:
def register_third_party_rules_callbacks(
self,
check_event_allowed: Optional[CHECK_EVENT_ALLOWED_CALLBACK] = None,
+ check_event_allowed_v2: Optional[CHECK_EVENT_ALLOWED_V2_CALLBACK] = None,
on_create_room: Optional[ON_CREATE_ROOM_CALLBACK] = None,
check_threepid_can_be_invited: Optional[
CHECK_THREEPID_CAN_BE_INVITED_CALLBACK
@@ -210,6 +218,9 @@ class ThirdPartyEventRules:
if check_event_allowed is not None:
self._check_event_allowed_callbacks.append(check_event_allowed)
+ if check_event_allowed_v2 is not None:
+ self._check_event_allowed_v2_callbacks.append(check_event_allowed_v2)
+
if on_create_room is not None:
self._on_create_room_callbacks.append(on_create_room)
@@ -256,7 +267,7 @@ class ThirdPartyEventRules:
self,
event: EventBase,
context: UnpersistedEventContextBase,
- ) -> Tuple[bool, Optional[dict]]:
+ ) -> Tuple[bool, Optional[dict], Optional[dict]]:
"""Check if a provided event should be allowed in the given context.
The module can return:
@@ -264,7 +275,8 @@ class ThirdPartyEventRules:
* False: the event is not allowed, and should be rejected with M_FORBIDDEN.
If the event is allowed, the module can also return a dictionary to use as a
- replacement for the event.
+ replacement for the event, and/or return a dictionary to use as the basis for
+ another event to be sent into the room.
Args:
event: The event to be checked.
@@ -274,8 +286,11 @@ class ThirdPartyEventRules:
The result from the ThirdPartyRules module, as above.
"""
# Bail out early without hitting the store if we don't have any callbacks to run.
- if len(self._check_event_allowed_callbacks) == 0:
- return True, None
+ if (
+ len(self._check_event_allowed_callbacks) == 0
+ and len(self._check_event_allowed_v2_callbacks) == 0
+ ):
+ return True, None, None
prev_state_ids = await context.get_prev_state_ids()
@@ -288,35 +303,63 @@ class ThirdPartyEventRules:
# the hashes and signatures.
event.freeze()
- for callback in self._check_event_allowed_callbacks:
- try:
- res, replacement_data = await delay_cancellation(
- callback(event, state_events)
- )
- except CancelledError:
- raise
- except SynapseError as e:
- # FIXME: Being able to throw SynapseErrors is relied upon by
- # some modules. PR #10386 accidentally broke this ability.
- # That said, we aren't keen on exposing this implementation detail
- # to modules and we should one day have a proper way to do what
- # is wanted.
- # This module callback needs a rework so that hacks such as
- # this one are not necessary.
- raise e
- except Exception:
- raise ModuleFailedException(
- "Failed to run `check_event_allowed` module API callback"
- )
+ if len(self._check_event_allowed_callbacks) != 0:
+ for callback in self._check_event_allowed_callbacks:
+ try:
+ res, replacement_data = await delay_cancellation(
+ callback(event, state_events)
+ )
+ except CancelledError:
+ raise
+ except SynapseError as e:
+ # FIXME: Being able to throw SynapseErrors is relied upon by
+ # some modules. PR #10386 accidentally broke this ability.
+ # That said, we aren't keen on exposing this implementation detail
+ # to modules and we should one day have a proper way to do what
+ # is wanted.
+ # This module callback needs a rework so that hacks such as
+ # this one are not necessary.
+ raise e
+ except Exception:
+ raise ModuleFailedException(
+ "Failed to run `check_event_allowed` module API callback"
+ )
- # Return if the event shouldn't be allowed or if the module came up with a
- # replacement dict for the event.
- if res is False:
- return res, None
- elif isinstance(replacement_data, dict):
- return True, replacement_data
+ # Return if the event shouldn't be allowed or if the module came up with a
+ # replacement dict for the event.
+ if res is False:
+ return res, None, None
+ elif isinstance(replacement_data, dict):
+ return True, replacement_data, None
+ else:
+ for v2_callback in self._check_event_allowed_v2_callbacks:
+ try:
+ res, replacement_data, new_event = await delay_cancellation(
+ v2_callback(event, state_events)
+ )
+ except CancelledError:
+ raise
+ except SynapseError as e:
+ # FIXME: Being able to throw SynapseErrors is relied upon by
+ # some modules. PR #10386 accidentally broke this ability.
+ # That said, we aren't keen on exposing this implementation detail
+ # to modules and we should one day have a proper way to do what
+ # is wanted.
+ # This module callback needs a rework so that hacks such as
+ # this one are not necessary.
+ raise e
+ except Exception:
+ raise ModuleFailedException(
+ "Failed to run `check_event_allowed_v2` module API callback"
+ )
- return True, None
+ # Return if the event shouldn't be allowed, if the module came up with a
+ # replacement dict for the event, or if the module wants to send a new event
+ if res is False:
+ return res, None, None
+ else:
+ return True, replacement_data, new_event
+ return True, None, None
async def on_create_room(
self, requester: Requester, config: dict, is_requester_admin: bool
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 80156ef343..dedcc620ac 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1007,6 +1007,7 @@ class FederationHandler:
(
event,
unpersisted_context,
+ _,
) = await self.event_creation_handler.create_new_client_event(
builder=builder,
prev_event_ids=prev_event_ids,
@@ -1198,7 +1199,7 @@ class FederationHandler:
},
)
- event, _ = await self.event_creation_handler.create_new_client_event(
+ event, _, _ = await self.event_creation_handler.create_new_client_event(
builder=builder
)
@@ -1251,9 +1252,10 @@ class FederationHandler:
(
event,
unpersisted_context,
+ _,
) = await self.event_creation_handler.create_new_client_event(builder=builder)
- event_allowed, _ = await self.third_party_event_rules.check_event_allowed(
+ event_allowed, _, _ = await self.third_party_event_rules.check_event_allowed(
event, unpersisted_context
)
if not event_allowed:
@@ -1446,6 +1448,7 @@ class FederationHandler:
(
event,
unpersisted_context,
+ _,
) = await self.event_creation_handler.create_new_client_event(
builder=builder
)
@@ -1528,6 +1531,7 @@ class FederationHandler:
(
event,
unpersisted_context,
+ _,
) = await self.event_creation_handler.create_new_client_event(
builder=builder
)
@@ -1610,6 +1614,7 @@ class FederationHandler:
(
event,
unpersisted_context,
+ _,
) = await self.event_creation_handler.create_new_client_event(builder=builder)
EventValidator().validate_new(event, self.config)
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index b7136f8d1c..0b974afe4d 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -404,9 +404,11 @@ class FederationEventHandler:
# for knock events, we run the third-party event rules. It's not entirely clear
# why we don't do this for other sorts of membership events.
if event.membership == Membership.KNOCK:
- event_allowed, _ = await self._third_party_event_rules.check_event_allowed(
- event, context
- )
+ (
+ event_allowed,
+ _,
+ _,
+ ) = await self._third_party_event_rules.check_event_allowed(event, context)
if not event_allowed:
logger.info("Sending of knock %s forbidden by third-party rules", event)
raise SynapseError(
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index da129ec16a..d283a938c0 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -16,6 +16,7 @@
# limitations under the License.
import logging
import random
+from builtins import dict
from http import HTTPStatus
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple
@@ -577,7 +578,7 @@ class EventCreationHandler:
state_map: Optional[StateMap[str]] = None,
for_batch: bool = False,
current_state_group: Optional[int] = None,
- ) -> Tuple[EventBase, UnpersistedEventContextBase]:
+ ) -> Tuple[EventBase, UnpersistedEventContextBase, Optional[dict]]:
"""
Given a dict from a client, create a new event. If bool for_batch is true, will
create an event using the prev_event_ids, and will create an event context for
@@ -649,7 +650,9 @@ class EventCreationHandler:
exceeded
Returns:
- Tuple of created event, Context
+ Tuple of created event, Context, and an optional event dict to form the basis
+ of a new event if third_party_rules would like to send an additional event as a
+ consequence of this event.
"""
await self.auth_blocking.check_auth_blocking(requester=requester)
@@ -711,7 +714,7 @@ class EventCreationHandler:
builder.internal_metadata.historical = historical
- event, unpersisted_context = await self.create_new_client_event(
+ event, unpersisted_context, new_event = await self.create_new_client_event(
builder=builder,
requester=requester,
allow_no_prev_events=allow_no_prev_events,
@@ -765,7 +768,7 @@ class EventCreationHandler:
)
self.validator.validate_new(event, self.config)
- return event, unpersisted_context
+ return event, unpersisted_context, new_event
async def _is_exempt_from_privacy_policy(
self, builder: EventBuilder, requester: Requester
@@ -1005,7 +1008,11 @@ class EventCreationHandler:
max_retries = 5
for i in range(max_retries):
try:
- event, unpersisted_context = await self.create_event(
+ (
+ event,
+ unpersisted_context,
+ third_party_event_dict,
+ ) = await self.create_event(
requester,
event_dict,
txn_id=txn_id,
@@ -1054,9 +1061,24 @@ class EventCreationHandler:
Codes.FORBIDDEN,
)
+ events_and_context = [(event, context)]
+ if third_party_event_dict:
+ (
+ third_party_event,
+ unpersisted_third_party_context,
+ _,
+ ) = await self.create_event(
+ requester,
+ third_party_event_dict,
+ )
+ third_party_context = await unpersisted_third_party_context.persist(
+ third_party_event
+ )
+ events_and_context.append((third_party_event, third_party_context))
+
ev = await self.handle_new_client_event(
requester=requester,
- events_and_context=[(event, context)],
+ events_and_context=events_and_context,
ratelimit=ratelimit,
ignore_shadow_ban=ignore_shadow_ban,
)
@@ -1086,7 +1108,7 @@ class EventCreationHandler:
state_map: Optional[StateMap[str]] = None,
for_batch: bool = False,
current_state_group: Optional[int] = None,
- ) -> Tuple[EventBase, UnpersistedEventContextBase]:
+ ) -> Tuple[EventBase, UnpersistedEventContextBase, Optional[dict]]:
"""Create a new event for a local client. If bool for_batch is true, will
create an event using the prev_event_ids, and will create an event context for
the event using the parameters state_map and current_state_group, thus these parameters
@@ -1135,7 +1157,9 @@ class EventCreationHandler:
batch persisting
Returns:
- Tuple of created event, UnpersistedEventContext
+ Tuple of created event, UnpersistedEventContext, and an optional event dict
+ to form the basis of a new event if third_party_rules would like to send an
+ additional event as a consequence of this event.
"""
# Strip down the state_event_ids to only what we need to auth the event.
# For example, we don't need extra m.room.member that don't match event.sender
@@ -1269,9 +1293,11 @@ class EventCreationHandler:
if requester:
context.app_service = requester.app_service
- res, new_content = await self.third_party_event_rules.check_event_allowed(
- event, context
- )
+ (
+ res,
+ new_content,
+ new_event,
+ ) = await self.third_party_event_rules.check_event_allowed(event, context)
if res is False:
logger.info(
"Event %s forbidden by third-party rules",
@@ -1291,7 +1317,7 @@ class EventCreationHandler:
await self._validate_event_relation(event)
logger.debug("Created event %s", event.event_id)
- return event, context
+ return event, context, new_event
async def _validate_event_relation(self, event: EventBase) -> None:
"""
@@ -2046,7 +2072,7 @@ class EventCreationHandler:
max_retries = 5
for i in range(max_retries):
try:
- event, unpersisted_context = await self.create_event(
+ event, unpersisted_context, _ = await self.create_event(
requester,
{
"type": EventTypes.Dummy,
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index be120cb12f..f28a602741 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -213,6 +213,7 @@ class RoomCreationHandler:
(
tombstone_event,
tombstone_unpersisted_context,
+ _,
) = await self.event_creation_handler.create_event(
requester,
{
@@ -1066,7 +1067,11 @@ class RoomCreationHandler:
content: JsonDict,
for_batch: bool,
**kwargs: Any,
- ) -> Tuple[EventBase, synapse.events.snapshot.UnpersistedEventContextBase]:
+ ) -> Tuple[
+ EventBase,
+ synapse.events.snapshot.UnpersistedEventContextBase,
+ Optional[dict],
+ ]:
"""
Creates an event and associated event context.
Args:
@@ -1088,6 +1093,7 @@ class RoomCreationHandler:
(
new_event,
new_unpersisted_context,
+ third_party_event,
) = await self.event_creation_handler.create_event(
creator,
event_dict,
@@ -1103,7 +1109,7 @@ class RoomCreationHandler:
prev_event = [new_event.event_id]
state_map[(new_event.type, new_event.state_key)] = new_event.event_id
- return new_event, new_unpersisted_context
+ return new_event, new_unpersisted_context, third_party_event
visibility = room_config.get("visibility", "private")
preset_config = room_config.get(
@@ -1121,7 +1127,7 @@ class RoomCreationHandler:
)
creation_content.update({"creator": creator_id})
- creation_event, unpersisted_creation_context = await create_event(
+ creation_event, unpersisted_creation_context, _ = await create_event(
EventTypes.Create, creation_content, False
)
creation_context = await unpersisted_creation_context.persist(creation_event)
@@ -1161,14 +1167,17 @@ class RoomCreationHandler:
current_state_group = event_to_state[member_event_id]
events_to_send = []
+ third_party_events_to_append = []
# We treat the power levels override specially as this needs to be one
# of the first events that get sent into a room.
pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None)
if pl_content is not None:
- power_event, power_context = await create_event(
+ power_event, power_context, power_tp_event = await create_event(
EventTypes.PowerLevels, pl_content, True
)
events_to_send.append((power_event, power_context))
+ if power_tp_event:
+ third_party_events_to_append.append(power_tp_event)
else:
power_level_content: JsonDict = {
"users": {creator_id: 100},
@@ -1211,58 +1220,92 @@ class RoomCreationHandler:
# apply those.
if power_level_content_override:
power_level_content.update(power_level_content_override)
- pl_event, pl_context = await create_event(
+ pl_event, pl_context, pl_tp_event = await create_event(
EventTypes.PowerLevels,
power_level_content,
True,
)
events_to_send.append((pl_event, pl_context))
+ if pl_tp_event:
+ third_party_events_to_append.append(pl_tp_event)
if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state:
- room_alias_event, room_alias_context = await create_event(
+ room_alias_event, room_alias_context, ra_tp_event = await create_event(
EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True
)
events_to_send.append((room_alias_event, room_alias_context))
+ if ra_tp_event:
+ third_party_events_to_append.append(ra_tp_event)
if (EventTypes.JoinRules, "") not in initial_state:
- join_rules_event, join_rules_context = await create_event(
+ join_rules_event, join_rules_context, jr_tp_event = await create_event(
EventTypes.JoinRules,
{"join_rule": config["join_rules"]},
True,
)
events_to_send.append((join_rules_event, join_rules_context))
+ if jr_tp_event:
+ third_party_events_to_append.append(jr_tp_event)
if (EventTypes.RoomHistoryVisibility, "") not in initial_state:
- visibility_event, visibility_context = await create_event(
+ visibility_event, visibility_context, vis_tp_event = await create_event(
EventTypes.RoomHistoryVisibility,
{"history_visibility": config["history_visibility"]},
True,
)
events_to_send.append((visibility_event, visibility_context))
+ if vis_tp_event:
+ third_party_events_to_append.append(vis_tp_event)
if config["guest_can_join"]:
if (EventTypes.GuestAccess, "") not in initial_state:
- guest_access_event, guest_access_context = await create_event(
+ (
+ guest_access_event,
+ guest_access_context,
+ ga_tp_event,
+ ) = await create_event(
EventTypes.GuestAccess,
{EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN},
True,
)
events_to_send.append((guest_access_event, guest_access_context))
+ if ga_tp_event:
+ third_party_events_to_append.append(ga_tp_event)
for (etype, state_key), content in initial_state.items():
- event, context = await create_event(
+ event, context, tp_event = await create_event(
etype, content, True, state_key=state_key
)
events_to_send.append((event, context))
+ if tp_event:
+ third_party_events_to_append.append(tp_event)
if config["encrypted"]:
- encryption_event, encryption_context = await create_event(
+ encryption_event, encryption_context, encrypt_tp_event = await create_event(
EventTypes.RoomEncryption,
{"algorithm": RoomEncryptionAlgorithms.DEFAULT},
True,
state_key="",
)
events_to_send.append((encryption_event, encryption_context))
+ if encrypt_tp_event:
+ third_party_events_to_append.append(encrypt_tp_event)
+
+ for event_dict in third_party_events_to_append:
+ (
+ event,
+ unpersisted_context,
+ _,
+ ) = await self.event_creation_handler.create_event(
+ creator,
+ event_dict,
+ prev_event_ids=prev_event,
+ state_map=state_map,
+ for_batch=True,
+ current_state_group=current_state_group,
+ )
+ context = await unpersisted_context.persist(event)
+ events_to_send.append((event, context))
if "name" in room_config:
name = room_config["name"]
diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py
index bf9df60218..8b5e02af17 100644
--- a/synapse/handlers/room_batch.py
+++ b/synapse/handlers/room_batch.py
@@ -327,7 +327,11 @@ class RoomBatchHandler:
# Mark all events as historical
event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True
- event, unpersisted_context = await self.event_creation_handler.create_event(
+ (
+ event,
+ unpersisted_context,
+ _,
+ ) = await self.event_creation_handler.create_event(
await self.create_requester_for_user_id_from_app_service(
ev["sender"], app_service_requester.app_service
),
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 509c557889..9d3096df8d 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -418,6 +418,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
(
event,
unpersisted_context,
+ third_party_event,
) = await self.event_creation_handler.create_event(
requester,
{
@@ -472,6 +473,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
ratelimit=ratelimit,
)
)
+ if third_party_event:
+ (
+ tp_event,
+ tp_unpersisted_context,
+ _,
+ ) = await self.event_creation_handler.create_event(
+ requester,
+ third_party_event,
+ prev_event_ids=[result_event.event_id],
+ )
+ tp_context = await tp_unpersisted_context.persist(tp_event)
+ await self.event_creation_handler.handle_new_client_event(
+ requester, events_and_context=[(tp_event, tp_context)]
+ )
if event.membership == Membership.LEAVE:
if prev_member_event_id:
@@ -1951,6 +1966,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
(
event,
unpersisted_context,
+ third_party_event_dict,
) = await self.event_creation_handler.create_event(
requester,
event_dict,
@@ -1962,10 +1978,24 @@ class RoomMemberMasterHandler(RoomMemberHandler):
context = await unpersisted_context.persist(event)
event.internal_metadata.out_of_band_membership = True
+ events_and_context = [(event, context)]
+ if third_party_event_dict:
+ (
+ third_party_event,
+ third_party_unpersisted_context,
+ _,
+ ) = await self.event_creation_handler.create_event(
+ requester, third_party_event_dict
+ )
+ third_party_context = await third_party_unpersisted_context.persist(
+ event
+ )
+ events_and_context.append((third_party_event, third_party_context))
+
result_event = (
await self.event_creation_handler.handle_new_client_event(
requester,
- events_and_context=[(event, context)],
+ events_and_context=events_and_context,
extra_users=[UserID.from_string(target_user)],
)
)
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index 9691d66b48..2e838e6572 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Tuple
+from typing import Optional, Tuple
from twisted.test.proto_helpers import MemoryReactor
@@ -81,7 +81,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
def _create_duplicate_event(
self, txn_id: str
- ) -> Tuple[EventBase, UnpersistedEventContextBase]:
+ ) -> Tuple[EventBase, UnpersistedEventContextBase, Optional[dict]]:
"""Create a new event with the given transaction ID. All events produced
by this method will be considered duplicates.
"""
@@ -109,7 +109,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
txn_id = "something_suitably_random"
- event1, unpersisted_context = self._create_duplicate_event(txn_id)
+ event1, unpersisted_context, _ = self._create_duplicate_event(txn_id)
context = self.get_success(unpersisted_context.persist(event1))
ret_event1 = self.get_success(
@@ -122,7 +122,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.assertEqual(event1.event_id, ret_event1.event_id)
- event2, unpersisted_context = self._create_duplicate_event(txn_id)
+ event2, unpersisted_context, _ = self._create_duplicate_event(txn_id)
context = self.get_success(unpersisted_context.persist(event2))
# We want to test that the deduplication at the persit event end works,
@@ -144,7 +144,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
# Let's test that calling `persist_event` directly also does the right
# thing.
- event3, unpersisted_context = self._create_duplicate_event(txn_id)
+ event3, unpersisted_context, _ = self._create_duplicate_event(txn_id)
context = self.get_success(unpersisted_context.persist(event3))
self.assertNotEqual(event1.event_id, event3.event_id)
@@ -160,8 +160,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
# Let's test that calling `persist_events` directly also does the right
# thing.
- event4, unpersisted_context = self._create_duplicate_event(txn_id)
+ event4, unpersisted_context, _ = self._create_duplicate_event(txn_id)
context = self.get_success(unpersisted_context.persist(event4))
+
self.assertNotEqual(event1.event_id, event3.event_id)
events, _ = self.get_success(
@@ -181,9 +182,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
txn_id = "something_else_suitably_random"
# Create two duplicate events to persist at the same time
- event1, unpersisted_context1 = self._create_duplicate_event(txn_id)
+ event1, unpersisted_context1, _ = self._create_duplicate_event(txn_id)
context1 = self.get_success(unpersisted_context1.persist(event1))
- event2, unpersisted_context2 = self._create_duplicate_event(txn_id)
+ event2, unpersisted_context2, _ = self._create_duplicate_event(txn_id)
context2 = self.get_success(unpersisted_context2.persist(event2))
# Ensure their event IDs are different to start with
@@ -209,7 +210,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
memberEvent, _ = self._create_and_persist_member_event()
# Try to create the event with empty prev_events bit with some auth_events
- event, _ = self.get_success(
+ event, _, _ = self.get_success(
self.handler.create_event(
self.requester,
{
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index aff1ec4758..161ff0a6c1 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -507,7 +507,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Lower the permissions of the inviter.
event_creation_handler = self.hs.get_event_creation_handler()
requester = create_requester(inviter)
- event, unpersisted_context = self.get_success(
+
+ event, unpersisted_context, _ = self.get_success(
event_creation_handler.create_event(
requester,
{
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index da4d240826..ce095eb68a 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -965,7 +965,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
},
)
- event, unpersisted_context = self.get_success(
+ event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py
index 46df0102f7..978c2d5a34 100644
--- a/tests/push/test_bulk_push_rule_evaluator.py
+++ b/tests/push/test_bulk_push_rule_evaluator.py
@@ -130,7 +130,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
# Create a new message event, and try to evaluate it under the dodgy
# power level event.
- event, unpersisted_context = self.get_success(
+ event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_event(
self.requester,
{
@@ -171,7 +171,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
"""Ensure that push rules are not calculated when disabled in the config"""
# Create a new message event which should cause a notification.
- event, unpersisted_context = self.get_success(
+ event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_event(
self.requester,
{
@@ -202,7 +202,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
) -> bool:
"""Returns true iff the `mentions` trigger an event push action."""
# Create a new message event which should cause a notification.
- event, unpersisted_context = self.get_success(
+ event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_event(
self.requester,
{
@@ -378,7 +378,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
# Create & persist an event to use as the parent of the relation.
- event, unpersisted_context = self.get_success(
+ event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_event(
self.requester,
{
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 4b8f889a71..c278f6bbad 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -2935,7 +2935,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
},
)
- event, unpersisted_context = self.get_success(
+ event, unpersisted_context, _ = self.get_success(
event_creation_handler.create_new_client_event(builder)
)
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 753ecc8d16..1bdb6bb6a5 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -275,6 +275,46 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
ev = channel.json_body
self.assertEqual(ev["content"]["x"], "y")
+ def test_add_event(self) -> None:
+ # needs checking of combo of return conditions, ie replace event and send event
+ async def check(
+ ev: EventBase, state: StateMap[EventBase]
+ ) -> Tuple[bool, Optional[JsonDict], Optional[dict]]:
+ event_dict = {
+ "type": "m.room.test",
+ "room_id": self.room_id,
+ "sender": self.user_id,
+ "content": {
+ "creator": "test_user",
+ "body": "message",
+ "msgtype": "message",
+ },
+ }
+ if ev.type == "message":
+ return True, None, event_dict
+ else:
+ return True, None, None
+
+ self.hs.get_third_party_event_rules()._check_event_allowed_v2_callbacks = [
+ check
+ ]
+
+ channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/send/message/1" % self.room_id,
+ {"x": "x"},
+ access_token=self.tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ events = self.get_success(
+ self.hs.get_datastores().main.get_forward_extremities_for_room(self.room_id)
+ )
+ event = events[1]
+
+ e = self.get_success(self.hs.get_datastores().main.get_event(event["event_id"]))
+ self.assertEqual("m.room.test", e.type)
+
def test_message_edit(self) -> None:
"""Ensure that the module doesn't cause issues with edited messages."""
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index e39b63edac..2a9aa9e21c 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -522,7 +522,8 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
latest_event_ids = self.get_success(
self.store.get_prev_events_for_room(room_id)
)
- event, unpersisted_context = self.get_success(
+
+ event, unpersisted_context, _ = self.get_success(
event_handler.create_event(
self.requester,
{
@@ -545,7 +546,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
assert state_ids1 is not None
state1 = set(state_ids1.values())
- event, unpersisted_context = self.get_success(
+ event, unpersisted_context, _ = self.get_success(
event_handler.create_event(
self.requester,
{
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 0100f7da14..b8b997f38a 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -74,7 +74,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
},
)
- event, unpersisted_context = self.get_success(
+ event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
@@ -98,7 +98,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
},
)
- event, unpersisted_context = self.get_success(
+ event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
@@ -123,7 +123,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
},
)
- event, unpersisted_context = self.get_success(
+ event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
@@ -265,7 +265,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
def internal_metadata(self) -> _EventInternalMetadata:
return self._base_builder.internal_metadata
- event_1, unpersisted_context_1 = self.get_success(
+ event_1, unpersisted_context_1, _ = self.get_success(
self.event_creation_handler.create_new_client_event(
cast(
EventBuilder,
@@ -290,7 +290,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.get_success(self._persistence.persist_event(event_1, context_1))
- event_2, unpersisted_context_2 = self.get_success(
+ event_2, unpersisted_context_2, _ = self.get_success(
self.event_creation_handler.create_new_client_event(
cast(
EventBuilder,
@@ -431,7 +431,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
},
)
- redaction_event, unpersisted_context = self.get_success(
+ redaction_event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 62aed6af0a..1a1214c7a2 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -67,7 +67,7 @@ class StateStoreTestCase(HomeserverTestCase):
},
)
- event, unpersisted_context = self.get_success(
+ event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
@@ -521,7 +521,7 @@ class StateStoreTestCase(HomeserverTestCase):
},
)
- event1, unpersisted_context1 = self.get_success(
+ event1, unpersisted_context1, _ = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
events_and_context.append((event1, unpersisted_context1))
@@ -537,7 +537,7 @@ class StateStoreTestCase(HomeserverTestCase):
},
)
- event2, unpersisted_context2 = self.get_success(
+ event2, unpersisted_context2, _ = self.get_success(
self.event_creation_handler.create_new_client_event(builder2)
)
events_and_context.append((event2, unpersisted_context2))
@@ -552,7 +552,7 @@ class StateStoreTestCase(HomeserverTestCase):
},
)
- event3, unpersisted_context3 = self.get_success(
+ event3, unpersisted_context3, _ = self.get_success(
self.event_creation_handler.create_new_client_event(builder3)
)
events_and_context.append((event3, unpersisted_context3))
@@ -568,7 +568,7 @@ class StateStoreTestCase(HomeserverTestCase):
},
)
- event4, unpersisted_context4 = self.get_success(
+ event4, unpersisted_context4, _ = self.get_success(
self.event_creation_handler.create_new_client_event(builder4)
)
events_and_context.append((event4, unpersisted_context4))
diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py
index 9679904c33..c619ef7f38 100644
--- a/tests/test_utils/event_injection.py
+++ b/tests/test_utils/event_injection.py
@@ -95,6 +95,7 @@ async def create_event(
(
event,
unpersisted_context,
+ _,
) = await hs.get_event_creation_handler().create_new_client_event(
builder, prev_event_ids=prev_event_ids
)
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 9ed330f554..6004490b8c 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -207,7 +207,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
},
)
- event, unpersisted_context = self.get_success(
+ event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
context = self.get_success(unpersisted_context.persist(event))
@@ -233,7 +233,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
},
)
- event, unpersisted_context = self.get_success(
+ event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
context = self.get_success(unpersisted_context.persist(event))
@@ -256,7 +256,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
},
)
- event, unpersisted_context = self.get_success(
+ event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_new_client_event(builder)
)
context = self.get_success(unpersisted_context.persist(event))
diff --git a/tests/unittest.py b/tests/unittest.py
index f9160faa1d..4b31f84494 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -723,7 +723,7 @@ class HomeserverTestCase(TestCase):
event_creator = self.hs.get_event_creation_handler()
requester = create_requester(user)
- event, unpersisted_context = self.get_success(
+ event, unpersisted_context, _ = self.get_success(
event_creator.create_event(
requester,
{
diff --git a/tests/utils.py b/tests/utils.py
index a0ac11bc5c..3badfb7d41 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -335,9 +335,11 @@ async def create_room(hs: HomeServer, room_id: str, creator_id: str) -> None:
},
)
- event, unpersisted_context = await event_creation_handler.create_new_client_event(
- builder
- )
+ (
+ event,
+ unpersisted_context,
+ _,
+ ) = await event_creation_handler.create_new_client_event(builder)
context = await unpersisted_context.persist(event)
await persistence_store.persist_event(event, context)
|