summary refs log tree commit diff
path: root/synapse/storage/event_federation.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/event_federation.py')
-rw-r--r--synapse/storage/event_federation.py64
1 files changed, 57 insertions, 7 deletions
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index f427aba879..180a764134 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -69,19 +69,21 @@ class EventFederationStore(SQLBaseStore):
 
         return results
 
-    def _get_prev_events(self, txn, event_id):
-        prev_ids = self._simple_select_onecol_txn(
+    def _get_latest_state_in_room(self, txn, room_id, type, state_key):
+        event_ids = self._simple_select_onecol_txn(
             txn,
-            table="event_edges",
+            table="state_forward_extremities",
             keyvalues={
-                "event_id": event_id,
+                "room_id": room_id,
+                "type": type,
+                "state_key": state_key,
             },
-            retcol="prev_event_id",
+            retcol="event_id",
         )
 
         results = []
-        for prev_event_id in prev_ids:
-            hashes = self._get_event_reference_hashes_txn(txn, prev_event_id)
+        for event_id in event_ids:
+            hashes = self._get_event_reference_hashes_txn(txn, event_id)
             prev_hashes = {
                 k: encode_base64(v) for k, v in hashes.items()
                 if k == "sha256"
@@ -90,6 +92,53 @@ class EventFederationStore(SQLBaseStore):
 
         return results
 
+    def _get_prev_events(self, txn, event_id):
+        results = self._get_prev_events_and_state(
+            txn,
+            event_id,
+            is_state=0,
+        )
+
+        return [(e_id, h, ) for e_id, h, _ in results]
+
+    def _get_prev_state(self, txn, event_id):
+        results = self._get_prev_events_and_state(
+            txn,
+            event_id,
+            is_state=1,
+        )
+
+        return [(e_id, h, ) for e_id, h, _ in results]
+
+    def _get_prev_events_and_state(self, txn, event_id, is_state=None):
+        keyvalues = {
+            "event_id": event_id,
+        }
+
+        if is_state is not None:
+            keyvalues["is_state"] = is_state
+
+        res = self._simple_select_list_txn(
+            txn,
+            table="event_edges",
+            keyvalues=keyvalues,
+            retcols=["prev_event_id", "is_state"],
+        )
+
+        results = []
+        for d in res:
+            hashes = self._get_event_reference_hashes_txn(
+                txn,
+                d["prev_event_id"]
+            )
+            prev_hashes = {
+                k: encode_base64(v) for k, v in hashes.items()
+                if k == "sha256"
+            }
+            results.append((d["prev_event_id"], prev_hashes, d["is_state"]))
+
+        return results
+
     def get_min_depth(self, room_id):
         return self.runInteraction(
             "get_min_depth",
@@ -135,6 +184,7 @@ class EventFederationStore(SQLBaseStore):
                     "event_id": event_id,
                     "prev_event_id": e_id,
                     "room_id": room_id,
+                    "is_state": 0,
                 },
                 or_ignore=True,
             )