diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index 0ee5eee385..76ab83d1f7 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -24,6 +24,7 @@ from typing import Tuple
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes
+from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
from synapse.rest import admin
@@ -51,11 +52,15 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
persistence = self.hs.get_storage_controllers().persistence
assert persistence is not None
self._persist_event_storage_controller = persistence
+ self.store = self.hs.get_datastores().main
self.user_id = self.register_user("tester", "foobar")
device_id = "dev-1"
access_token = self.login("tester", "foobar", device_id=device_id)
self.room_id = self.helper.create_room_as(self.user_id, tok=access_token)
+ self.private_room_id = self.helper.create_room_as(
+ self.user_id, tok=access_token, extra_content={"preset": "private_chat"}
+ )
self.requester = create_requester(self.user_id, device_id=device_id)
@@ -285,6 +290,41 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
AssertionError,
)
+ def test_call_invite_event_creation_fails_in_public_room(self) -> None:
+ # get prev_events for room
+ prev_events = self.get_success(
+ self.store.get_prev_events_for_room(self.room_id)
+ )
+
+ # the invite in a public room should fail
+ self.get_failure(
+ self.handler.create_event(
+ self.requester,
+ {
+ "type": EventTypes.CallInvite,
+ "room_id": self.room_id,
+ "sender": self.requester.user.to_string(),
+ },
+ prev_event_ids=prev_events,
+ auth_event_ids=prev_events,
+ ),
+ SynapseError,
+ )
+
+ # but a call invite in a private room should succeed
+ self.get_success(
+ self.handler.create_event(
+ self.requester,
+ {
+ "type": EventTypes.CallInvite,
+ "room_id": self.private_room_id,
+ "sender": self.requester.user.to_string(),
+ },
+ prev_event_ids=prev_events,
+ auth_event_ids=prev_events,
+ )
+ )
+
class ServerAclValidationTestCase(unittest.HomeserverTestCase):
servlets = [
|