summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/events/third_party_rules.py12
-rw-r--r--tests/rest/client/test_rooms.py4
2 files changed, 12 insertions, 4 deletions
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 72ab696898..7ee48cc370 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -231,7 +231,11 @@ class ThirdPartyEventRules:
             self._on_threepid_bind_callbacks.append(on_threepid_bind)
 
     async def check_event_allowed(
-        self, event: EventBase, context: EventContext
+        self,
+        event: EventBase,
+        context: EventContext,
+        for_batch: bool = False,
+        state_map: Optional[StateMap[str]] = None,
     ) -> Tuple[bool, Optional[dict]]:
         """Check if a provided event should be allowed in the given context.
 
@@ -253,7 +257,11 @@ class ThirdPartyEventRules:
         if len(self._check_event_allowed_callbacks) == 0:
             return True, None
 
-        prev_state_ids = await context.get_prev_state_ids()
+        if for_batch:
+            assert state_map is not None
+            prev_state_ids = state_map
+        else:
+            prev_state_ids = await context.get_prev_state_ids()
 
         # Retrieve the state events from the database.
         events = await self.store.get_events(prev_state_ids.values())
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 716366eb90..255d1dfb8f 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -714,7 +714,7 @@ class RoomsCreateTestCase(RoomBase):
         self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
         self.assertTrue("room_id" in channel.json_body)
         assert channel.resource_usage is not None
-        self.assertEqual(34, channel.resource_usage.db_txn_count)
+        self.assertEqual(32, channel.resource_usage.db_txn_count)
 
     def test_post_room_initial_state(self) -> None:
         # POST with initial_state config key, expect new room id
@@ -727,7 +727,7 @@ class RoomsCreateTestCase(RoomBase):
         self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
         self.assertTrue("room_id" in channel.json_body)
         assert channel.resource_usage is not None
-        self.assertEqual(37, channel.resource_usage.db_txn_count)
+        self.assertEqual(34, channel.resource_usage.db_txn_count)
 
     def test_post_room_visibility_key(self) -> None:
         # POST with visibility config key, expect new room id