summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/handlers/test_room.py68
1 files changed, 34 insertions, 34 deletions
diff --git a/tests/handlers/test_room.py b/tests/handlers/test_room.py
index 0279ab703a..1e6e7c3602 100644
--- a/tests/handlers/test_room.py
+++ b/tests/handlers/test_room.py
@@ -20,7 +20,7 @@ from tests import unittest
 from synapse.api.events.room import (
     RoomMemberEvent,
 )
-from synapse.api.constants import Membership
+from synapse.api.constants import EventTypes, Membership
 from synapse.handlers.room import RoomMemberHandler, RoomCreationHandler
 from synapse.handlers.profile import ProfileHandler
 from synapse.server import HomeServer
@@ -254,13 +254,9 @@ class RoomCreationTest(unittest.TestCase):
             notifier=NonCallableMock(spec_set=["on_new_room_event"]),
             handlers=NonCallableMock(spec_set=[
                 "room_creation_handler",
-                "room_member_handler",
-                "federation_handler",
+                "message_handler",
             ]),
             auth=NonCallableMock(spec_set=["check", "add_auth_events"]),
-            state_handler=NonCallableMock(spec_set=[
-                "annotate_event_with_state",
-            ]),
             ratelimiter=NonCallableMock(spec_set=[
                 "send_message",
             ]),
@@ -271,30 +267,12 @@ class RoomCreationTest(unittest.TestCase):
             "handle_new_event",
         ])
 
-        self.datastore = hs.get_datastore()
         self.handlers = hs.get_handlers()
-        self.notifier = hs.get_notifier()
-        self.state_handler = hs.get_state_handler()
-        self.hs = hs
-
-        self.handlers.federation_handler = self.federation
 
-        self.handlers.room_creation_handler = RoomCreationHandler(self.hs)
+        self.handlers.room_creation_handler = RoomCreationHandler(hs)
         self.room_creation_handler = self.handlers.room_creation_handler
 
-        self.handlers.room_member_handler = NonCallableMock(spec_set=[
-            "change_membership"
-        ])
-        self.room_member_handler = self.handlers.room_member_handler
-
-        def annotate(event):
-            event.state_events = {}
-            return defer.succeed(None)
-        self.state_handler.annotate_event_with_state.side_effect = annotate
-
-        def hosts(room):
-            return defer.succeed([])
-        self.datastore.get_joined_hosts_for_room.side_effect = hosts
+        self.message_handler = self.handlers.message_handler
 
         self.ratelimiter = hs.get_ratelimiter()
         self.ratelimiter.send_message.return_value = (True, 0)
@@ -311,14 +289,36 @@ class RoomCreationTest(unittest.TestCase):
             config=config,
         )
 
-        self.assertTrue(self.room_member_handler.change_membership.called)
-        join_event = self.room_member_handler.change_membership.call_args[0][0]
+        self.assertTrue(self.message_handler.handle_event.called)
+
+        event_dicts = [
+            e[0][0] for e in self.message_handler.handle_event.call_args_list
+        ]
 
-        self.assertEquals(RoomMemberEvent.TYPE, join_event.type)
-        self.assertEquals(room_id, join_event.room_id)
-        self.assertEquals(user_id, join_event.user_id)
-        self.assertEquals(user_id, join_event.state_key)
+        self.assertTrue(len(event_dicts) > 3)
 
-        self.assertTrue(self.state_handler.annotate_event_with_state.called)
+        self.assertDictContainsSubset(
+            {
+                "type": EventTypes.Create,
+                "sender": user_id,
+                "room_id": room_id,
+            },
+            event_dicts[0]
+        )
+
+        self.assertEqual(user_id, event_dicts[0]["content"]["creator"])
 
-        self.assertTrue(self.federation.handle_new_event.called)
+        self.assertDictContainsSubset(
+            {
+                "type": EventTypes.Member,
+                "sender": user_id,
+                "room_id": room_id,
+                "state_key": user_id,
+            },
+            event_dicts[1]
+        )
+
+        self.assertEqual(
+            Membership.JOIN,
+            event_dicts[1]["content"]["membership"]
+        )