diff options
-rw-r--r-- | synapse/events/third_party_rules.py | 12 | ||||
-rw-r--r-- | tests/rest/client/test_rooms.py | 4 |
2 files changed, 12 insertions, 4 deletions
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 72ab696898..7ee48cc370 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -231,7 +231,11 @@ class ThirdPartyEventRules: self._on_threepid_bind_callbacks.append(on_threepid_bind) async def check_event_allowed( - self, event: EventBase, context: EventContext + self, + event: EventBase, + context: EventContext, + for_batch: bool = False, + state_map: Optional[StateMap[str]] = None, ) -> Tuple[bool, Optional[dict]]: """Check if a provided event should be allowed in the given context. @@ -253,7 +257,11 @@ class ThirdPartyEventRules: if len(self._check_event_allowed_callbacks) == 0: return True, None - prev_state_ids = await context.get_prev_state_ids() + if for_batch: + assert state_map is not None + prev_state_ids = state_map + else: + prev_state_ids = await context.get_prev_state_ids() # Retrieve the state events from the database. events = await self.store.get_events(prev_state_ids.values()) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 716366eb90..255d1dfb8f 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -714,7 +714,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(34, channel.resource_usage.db_txn_count) + self.assertEqual(32, channel.resource_usage.db_txn_count) def test_post_room_initial_state(self) -> None: # POST with initial_state config key, expect new room id @@ -727,7 +727,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(37, channel.resource_usage.db_txn_count) + self.assertEqual(34, channel.resource_usage.db_txn_count) def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id |