diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 6b8fed4502..2d62fc2ed0 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -242,8 +242,8 @@ class DataStore(RoomMemberStore, RoomStore,
"state_key": event.state_key,
}
- if hasattr(event, "prev_state"):
- vals["prev_state"] = event.prev_state
+ if hasattr(event, "replaces_state"):
+ vals["prev_state"] = event.replaces_state
self._simple_insert_txn(txn, "state_events", vals)
@@ -258,6 +258,40 @@ class DataStore(RoomMemberStore, RoomStore,
}
)
+ for e_id, h in event.prev_state:
+ self._simple_insert_txn(
+ txn,
+ table="event_edges",
+ values={
+ "event_id": event.event_id,
+ "prev_event_id": e_id,
+ "room_id": event.room_id,
+ "is_state": 1,
+ },
+ or_ignore=True,
+ )
+
+ if not backfilled:
+ self._simple_insert_txn(
+ txn,
+ table="state_forward_extremities",
+ values={
+ "event_id": event.event_id,
+ "room_id": event.room_id,
+ "type": event.type,
+ "state_key": event.state_key,
+ }
+ )
+
+ for prev_state_id, _ in event.prev_state:
+ self._simple_delete_txn(
+ txn,
+ table="state_forward_extremities",
+ keyvalues={
+ "event_id": prev_state_id,
+ }
+ )
+
for hash_alg, hash_base64 in event.hashes.items():
hash_bytes = decode_base64(hash_base64)
self._store_event_content_hash_txn(
@@ -357,7 +391,7 @@ class DataStore(RoomMemberStore, RoomStore,
],
)
- def snapshot_room(self, room_id, user_id, state_type=None, state_key=None):
+ def snapshot_room(self, event):
"""Snapshot the room for an update by a user
Args:
room_id (synapse.types.RoomId): The room to snapshot.
@@ -368,16 +402,29 @@ class DataStore(RoomMemberStore, RoomStore,
synapse.storage.Snapshot: A snapshot of the state of the room.
"""
def _snapshot(txn):
- membership_state = self._get_room_member(txn, user_id, room_id)
- prev_events = self._get_latest_events_in_room(txn, room_id)
+ prev_events = self._get_latest_events_in_room(
+ txn,
+ event.room_id
+ )
+
+ prev_state = None
+ state_key = None
+ if hasattr(event, "state_key"):
+ state_key = event.state_key
+ prev_state = self._get_latest_state_in_room(
+ txn,
+ event.room_id,
+ type=event.type,
+ state_key=state_key,
+ )
return Snapshot(
store=self,
- room_id=room_id,
- user_id=user_id,
+ room_id=event.room_id,
+ user_id=event.user_id,
prev_events=prev_events,
- membership_state=membership_state,
- state_type=state_type,
+ prev_state=prev_state,
+ state_type=event.type,
state_key=state_key,
)
@@ -400,30 +447,29 @@ class Snapshot(object):
"""
def __init__(self, store, room_id, user_id, prev_events,
- membership_state, state_type=None, state_key=None,
- prev_state_pdu=None):
+ prev_state, state_type=None, state_key=None):
self.store = store
self.room_id = room_id
self.user_id = user_id
self.prev_events = prev_events
- self.membership_state = membership_state
+ self.prev_state = prev_state
self.state_type = state_type
self.state_key = state_key
- self.prev_state_pdu = prev_state_pdu
def fill_out_prev_events(self, event):
- if hasattr(event, "prev_events"):
- return
+ if not hasattr(event, "prev_events"):
+ event.prev_events = [
+ (event_id, hashes)
+ for event_id, hashes, _ in self.prev_events
+ ]
- event.prev_events = [
- (event_id, hashes)
- for event_id, hashes, _ in self.prev_events
- ]
+ if self.prev_events:
+ event.depth = max([int(v) for _, _, v in self.prev_events]) + 1
+ else:
+ event.depth = 0
- if self.prev_events:
- event.depth = max([int(v) for _, _, v in self.prev_events]) + 1
- else:
- event.depth = 0
+ if not hasattr(event, "prev_state") and self.prev_state is not None:
+ event.prev_state = self.prev_state
def schema_path(schema):
|