diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index 69d384442f..9691d66b48 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -18,7 +18,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes
from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
@@ -79,7 +79,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
return memberEvent, memberEventContext
- def _create_duplicate_event(self, txn_id: str) -> Tuple[EventBase, EventContext]:
+ def _create_duplicate_event(
+ self, txn_id: str
+ ) -> Tuple[EventBase, UnpersistedEventContextBase]:
"""Create a new event with the given transaction ID. All events produced
by this method will be considered duplicates.
"""
@@ -107,7 +109,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
txn_id = "something_suitably_random"
- event1, context = self._create_duplicate_event(txn_id)
+ event1, unpersisted_context = self._create_duplicate_event(txn_id)
+ context = self.get_success(unpersisted_context.persist(event1))
ret_event1 = self.get_success(
self.handler.handle_new_client_event(
@@ -119,7 +122,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.assertEqual(event1.event_id, ret_event1.event_id)
- event2, context = self._create_duplicate_event(txn_id)
+ event2, unpersisted_context = self._create_duplicate_event(txn_id)
+ context = self.get_success(unpersisted_context.persist(event2))
# We want to test that the deduplication at the persit event end works,
# so we want to make sure we test with different events.
@@ -140,7 +144,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
# Let's test that calling `persist_event` directly also does the right
# thing.
- event3, context = self._create_duplicate_event(txn_id)
+ event3, unpersisted_context = self._create_duplicate_event(txn_id)
+ context = self.get_success(unpersisted_context.persist(event3))
+
self.assertNotEqual(event1.event_id, event3.event_id)
ret_event3, event_pos3, _ = self.get_success(
@@ -154,7 +160,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
# Let's test that calling `persist_events` directly also does the right
# thing.
- event4, context = self._create_duplicate_event(txn_id)
+ event4, unpersisted_context = self._create_duplicate_event(txn_id)
+ context = self.get_success(unpersisted_context.persist(event4))
self.assertNotEqual(event1.event_id, event3.event_id)
events, _ = self.get_success(
@@ -174,8 +181,10 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
txn_id = "something_else_suitably_random"
# Create two duplicate events to persist at the same time
- event1, context1 = self._create_duplicate_event(txn_id)
- event2, context2 = self._create_duplicate_event(txn_id)
+ event1, unpersisted_context1 = self._create_duplicate_event(txn_id)
+ context1 = self.get_success(unpersisted_context1.persist(event1))
+ event2, unpersisted_context2 = self._create_duplicate_event(txn_id)
+ context2 = self.get_success(unpersisted_context2.persist(event2))
# Ensure their event IDs are different to start with
self.assertNotEqual(event1.event_id, event2.event_id)
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 1db99b3c00..aff1ec4758 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -507,7 +507,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Lower the permissions of the inviter.
event_creation_handler = self.hs.get_event_creation_handler()
requester = create_requester(inviter)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
event_creation_handler.create_event(
requester,
{
@@ -519,6 +519,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
},
)
)
+ context = self.get_success(unpersisted_context.persist(event))
self.get_success(
event_creation_handler.handle_new_client_event(
requester, events_and_context=[(event, context)]
diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py
index dce6899e78..1458076a90 100644
--- a/tests/push/test_bulk_push_rule_evaluator.py
+++ b/tests/push/test_bulk_push_rule_evaluator.py
@@ -130,7 +130,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
# Create a new message event, and try to evaluate it under the dodgy
# power level event.
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_event(
self.requester,
{
@@ -145,6 +145,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
prev_event_ids=[pl_event_id],
)
)
+ context = self.get_success(unpersisted_context.persist(event))
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
# should not raise
@@ -170,7 +171,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
"""Ensure that push rules are not calculated when disabled in the config"""
# Create a new message event which should cause a notification.
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_event(
self.requester,
{
@@ -184,6 +185,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
},
)
)
+ context = self.get_success(unpersisted_context.persist(event))
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
# Mock the method which calculates push rules -- we do this instead of
@@ -200,7 +202,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
) -> bool:
"""Returns true iff the `mentions` trigger an event push action."""
# Create a new message event which should cause a notification.
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_event(
self.requester,
{
@@ -211,7 +213,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
},
)
)
-
+ context = self.get_success(unpersisted_context.persist(event))
# Execute the push rule machinery.
self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)]))
@@ -390,7 +392,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
# Create & persist an event to use as the parent of the relation.
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
self.event_creation_handler.create_event(
self.requester,
{
@@ -404,6 +406,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
},
)
)
+ context = self.get_success(unpersisted_context.persist(event))
self.get_success(
self.event_creation_handler.handle_new_client_event(
self.requester, events_and_context=[(event, context)]
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 4dd763096d..a4900703c4 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -713,7 +713,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None
- self.assertEqual(33, channel.resource_usage.db_txn_count)
+ self.assertEqual(30, channel.resource_usage.db_txn_count)
def test_post_room_initial_state(self) -> None:
# POST with initial_state config key, expect new room id
@@ -726,7 +726,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None
- self.assertEqual(36, channel.resource_usage.db_txn_count)
+ self.assertEqual(32, channel.resource_usage.db_txn_count)
def test_post_room_visibility_key(self) -> None:
# POST with visibility config key, expect new room id
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index 73d11e7786..e39b63edac 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -522,7 +522,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
latest_event_ids = self.get_success(
self.store.get_prev_events_for_room(room_id)
)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
event_handler.create_event(
self.requester,
{
@@ -535,6 +535,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
prev_event_ids=latest_event_ids,
)
)
+ context = self.get_success(unpersisted_context.persist(event))
self.get_success(
event_handler.handle_new_client_event(
self.requester, events_and_context=[(event, context)]
@@ -544,7 +545,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
assert state_ids1 is not None
state1 = set(state_ids1.values())
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
event_handler.create_event(
self.requester,
{
@@ -557,6 +558,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
prev_event_ids=latest_event_ids,
)
)
+ context = self.get_success(unpersisted_context.persist(event))
self.get_success(
event_handler.handle_new_client_event(
self.requester, events_and_context=[(event, context)]
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index e82c03f597..62aed6af0a 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -496,3 +496,129 @@ class StateStoreTestCase(HomeserverTestCase):
self.assertEqual(is_all, True)
self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
+
+ def test_batched_state_group_storing(self) -> None:
+ creation_event = self.inject_state_event(
+ self.room, self.u_alice, EventTypes.Create, "", {}
+ )
+ state_to_event = self.get_success(
+ self.storage.state.get_state_groups(
+ self.room.to_string(), [creation_event.event_id]
+ )
+ )
+ current_state_group = list(state_to_event.keys())[0]
+
+ # create some unpersisted events and event contexts to store against room
+ events_and_context = []
+ builder = self.event_builder_factory.for_room_version(
+ RoomVersions.V1,
+ {
+ "type": EventTypes.Name,
+ "sender": self.u_alice.to_string(),
+ "state_key": "",
+ "room_id": self.room.to_string(),
+ "content": {"name": "first rename of room"},
+ },
+ )
+
+ event1, unpersisted_context1 = self.get_success(
+ self.event_creation_handler.create_new_client_event(builder)
+ )
+ events_and_context.append((event1, unpersisted_context1))
+
+ builder2 = self.event_builder_factory.for_room_version(
+ RoomVersions.V1,
+ {
+ "type": EventTypes.JoinRules,
+ "sender": self.u_alice.to_string(),
+ "state_key": "",
+ "room_id": self.room.to_string(),
+ "content": {"join_rule": "private"},
+ },
+ )
+
+ event2, unpersisted_context2 = self.get_success(
+ self.event_creation_handler.create_new_client_event(builder2)
+ )
+ events_and_context.append((event2, unpersisted_context2))
+
+ builder3 = self.event_builder_factory.for_room_version(
+ RoomVersions.V1,
+ {
+ "type": EventTypes.Message,
+ "sender": self.u_alice.to_string(),
+ "room_id": self.room.to_string(),
+ "content": {"body": "hello from event 3", "msgtype": "m.text"},
+ },
+ )
+
+ event3, unpersisted_context3 = self.get_success(
+ self.event_creation_handler.create_new_client_event(builder3)
+ )
+ events_and_context.append((event3, unpersisted_context3))
+
+ builder4 = self.event_builder_factory.for_room_version(
+ RoomVersions.V1,
+ {
+ "type": EventTypes.JoinRules,
+ "sender": self.u_alice.to_string(),
+ "state_key": "",
+ "room_id": self.room.to_string(),
+ "content": {"join_rule": "public"},
+ },
+ )
+
+ event4, unpersisted_context4 = self.get_success(
+ self.event_creation_handler.create_new_client_event(builder4)
+ )
+ events_and_context.append((event4, unpersisted_context4))
+
+ processed_events_and_context = self.get_success(
+ self.hs.get_datastores().state.store_state_deltas_for_batched(
+ events_and_context, self.room.to_string(), current_state_group
+ )
+ )
+
+ # check that only state events are in state_groups, and all state events are in state_groups
+ res = self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="state_groups",
+ keyvalues=None,
+ retcols=("event_id",),
+ )
+ )
+
+ events = []
+ for result in res:
+ self.assertNotIn(event3.event_id, result)
+ events.append(result.get("event_id"))
+
+ for event, _ in processed_events_and_context:
+ if event.is_state():
+ self.assertIn(event.event_id, events)
+
+ # check that each unique state has state group in state_groups_state and that the
+ # type/state key is correct, and check that each state event's state group
+ # has an entry and prev event in state_group_edges
+ for event, context in processed_events_and_context:
+ if event.is_state():
+ state = self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="state_groups_state",
+ keyvalues={"state_group": context.state_group_after_event},
+ retcols=("type", "state_key"),
+ )
+ )
+ self.assertEqual(event.type, state[0].get("type"))
+ self.assertEqual(event.state_key, state[0].get("state_key"))
+
+ groups = self.get_success(
+ self.store.db_pool.simple_select_list(
+ table="state_group_edges",
+ keyvalues={"state_group": str(context.state_group_after_event)},
+ retcols=("*",),
+ )
+ )
+ self.assertEqual(
+ context.state_group_before_event, groups[0].get("prev_state_group")
+ )
diff --git a/tests/unittest.py b/tests/unittest.py
index b21e7f1221..f9160faa1d 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -723,7 +723,7 @@ class HomeserverTestCase(TestCase):
event_creator = self.hs.get_event_creation_handler()
requester = create_requester(user)
- event, context = self.get_success(
+ event, unpersisted_context = self.get_success(
event_creator.create_event(
requester,
{
@@ -735,7 +735,7 @@ class HomeserverTestCase(TestCase):
prev_event_ids=prev_event_ids,
)
)
-
+ context = self.get_success(unpersisted_context.persist(event))
if soft_failed:
event.internal_metadata.soft_failed = True
|