diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 9b02ce0dfd..47dcc6544d 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -133,7 +133,7 @@ class MessageHandler(BaseRoomHandler):
if stamp_event:
event.content["hsob_ts"] = int(self.clock.time_msec())
- yield self.state_handler.handle_new_event(event)
+ yield self.state_handler.handle_new_event(event, snapshot)
yield self._on_new_room_event(event, snapshot)
@@ -362,6 +362,13 @@ class RoomCreationHandler(BaseRoomHandler):
content=config,
)
+ snapshot = yield self.store.snapshot_room(
+ room_id=room_id,
+ user_id=user_id,
+ state_type=RoomConfigEvent.TYPE,
+ state_key="",
+ )
+
if room_alias:
yield self.store.create_room_alias_association(
room_id=room_id,
@@ -369,11 +376,11 @@ class RoomCreationHandler(BaseRoomHandler):
servers=[self.hs.hostname],
)
- yield self.state_handler.handle_new_event(config_event)
+ yield self.state_handler.handle_new_event(config_event, snapshot)
# store_id = persist...
federation_handler = self.hs.get_handlers().federation_handler
- yield federation_handler.handle_new_event(config_event)
+ yield federation_handler.handle_new_event(config_event, snapshot)
# self.notifier.on_new_room_event(event, store_id)
content = {"membership": Membership.JOIN}
diff --git a/synapse/state.py b/synapse/state.py
index ca8e1ca630..e1a1a159bb 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -45,7 +45,7 @@ class StateHandler(object):
@defer.inlineCallbacks
@log_function
- def handle_new_event(self, event):
+ def handle_new_event(self, event, snapshot):
""" Given an event this works out if a) we have sufficient power level
to update the state and b) works out what the prev_state should be.
@@ -70,25 +70,13 @@ class StateHandler(object):
# Now I need to fill out the prev state and work out if it has auth
# (w.r.t. to power levels)
- results = yield self.store.get_latest_pdus_in_context(
- event.room_id
- )
+ snapshot.fill_out_prev_events(event)
event.prev_events = [
- encode_event_id(p_id, origin) for p_id, origin, _ in results
- ]
- event.prev_events = [
e for e in event.prev_events if e != event.event_id
]
- if results:
- event.depth = max([int(v) for _, _, v in results]) + 1
- else:
- event.depth = 0
-
- current_state = yield self.store.get_current_state_pdu(
- key.context, key.type, key.state_key
- )
+ current_state = snapshot.prev_state_pdu
if current_state:
event.prev_state = encode_event_id(
diff --git a/tests/handlers/test_room.py b/tests/handlers/test_room.py
index 613f5c307e..a84dbcc471 100644
--- a/tests/handlers/test_room.py
+++ b/tests/handlers/test_room.py
@@ -330,6 +330,7 @@ class RoomCreationTest(unittest.TestCase):
db_pool=None,
datastore=NonCallableMock(spec_set=[
"store_room",
+ "snapshot_room",
]),
http_client=NonCallableMock(spec_set=[]),
notifier=NonCallableMock(spec_set=["on_new_room_event"]),
diff --git a/tests/test_state.py b/tests/test_state.py
index e64d15a3a2..58fd0bf3be 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -243,21 +243,24 @@ class StateTestCase(unittest.TestCase):
state_pdu = new_fake_pdu_entry("C", "test", "mem", "x", "A", 20)
- tup = ("pdu_id", "origin.com", 5)
- pdus = [tup]
+ snapshot = Mock()
+ snapshot.prev_state_pdu = state_pdu
+ event_id = "pdu_id@origin.com"
- self.persistence.get_latest_pdus_in_context.return_value = pdus
- self.persistence.get_current_state_pdu.return_value = state_pdu
+ def fill_out_prev_events(event):
+ event.prev_events = [event_id]
+ event.depth = 6
+ snapshot.fill_out_prev_events = fill_out_prev_events
- yield self.state.handle_new_event(event)
+ yield self.state.handle_new_event(event, snapshot)
- self.assertLess(tup[2], event.depth)
+ self.assertLess(5, event.depth)
self.assertEquals(1, len(event.prev_events))
prev_id = event.prev_events[0]
- self.assertEqual(encode_event_id(tup[0], tup[1]), prev_id)
+ self.assertEqual(event_id, prev_id)
self.assertEqual(
encode_event_id(state_pdu.pdu_id, state_pdu.origin),
|