summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/__init__.py92
-rw-r--r--synapse/storage/_base.py66
-rw-r--r--synapse/storage/event_federation.py64
-rw-r--r--synapse/storage/schema/event_edges.sql40
4 files changed, 203 insertions, 59 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 6b8fed4502..2d62fc2ed0 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -242,8 +242,8 @@ class DataStore(RoomMemberStore, RoomStore,
                 "state_key": event.state_key,
             }
 
-            if hasattr(event, "prev_state"):
-                vals["prev_state"] = event.prev_state
+            if hasattr(event, "replaces_state"):
+                vals["prev_state"] = event.replaces_state
 
             self._simple_insert_txn(txn, "state_events", vals)
 
@@ -258,6 +258,40 @@ class DataStore(RoomMemberStore, RoomStore,
                 }
             )
 
+            for e_id, h in event.prev_state:
+                self._simple_insert_txn(
+                    txn,
+                    table="event_edges",
+                    values={
+                        "event_id": event.event_id,
+                        "prev_event_id": e_id,
+                        "room_id": event.room_id,
+                        "is_state": 1,
+                    },
+                    or_ignore=True,
+                )
+
+            if not backfilled:
+                self._simple_insert_txn(
+                    txn,
+                    table="state_forward_extremities",
+                    values={
+                        "event_id": event.event_id,
+                        "room_id": event.room_id,
+                        "type": event.type,
+                        "state_key": event.state_key,
+                    }
+                )
+
+                for prev_state_id, _ in event.prev_state:
+                    self._simple_delete_txn(
+                        txn,
+                        table="state_forward_extremities",
+                        keyvalues={
+                            "event_id": prev_state_id,
+                        }
+                    )
+
         for hash_alg, hash_base64 in event.hashes.items():
             hash_bytes = decode_base64(hash_base64)
             self._store_event_content_hash_txn(
@@ -357,7 +391,7 @@ class DataStore(RoomMemberStore, RoomStore,
             ],
         )
 
-    def snapshot_room(self, room_id, user_id, state_type=None, state_key=None):
+    def snapshot_room(self, event):
         """Snapshot the room for an update by a user
         Args:
             room_id (synapse.types.RoomId): The room to snapshot.
@@ -368,16 +402,29 @@ class DataStore(RoomMemberStore, RoomStore,
             synapse.storage.Snapshot: A snapshot of the state of the room.
         """
         def _snapshot(txn):
-            membership_state = self._get_room_member(txn, user_id, room_id)
-            prev_events = self._get_latest_events_in_room(txn, room_id)
+            prev_events = self._get_latest_events_in_room(
+                txn,
+                event.room_id
+            )
+
+            prev_state = None
+            state_key = None
+            if hasattr(event, "state_key"):
+                state_key = event.state_key
+                prev_state = self._get_latest_state_in_room(
+                    txn,
+                    event.room_id,
+                    type=event.type,
+                    state_key=state_key,
+                )
 
             return Snapshot(
                 store=self,
-                room_id=room_id,
-                user_id=user_id,
+                room_id=event.room_id,
+                user_id=event.user_id,
                 prev_events=prev_events,
-                membership_state=membership_state,
-                state_type=state_type,
+                prev_state=prev_state,
+                state_type=event.type,
                 state_key=state_key,
             )
 
@@ -400,30 +447,29 @@ class Snapshot(object):
     """
 
     def __init__(self, store, room_id, user_id, prev_events,
-                 membership_state, state_type=None, state_key=None,
-                 prev_state_pdu=None):
+                 prev_state, state_type=None, state_key=None):
         self.store = store
         self.room_id = room_id
         self.user_id = user_id
         self.prev_events = prev_events
-        self.membership_state = membership_state
+        self.prev_state = prev_state
         self.state_type = state_type
         self.state_key = state_key
-        self.prev_state_pdu = prev_state_pdu
 
     def fill_out_prev_events(self, event):
-        if hasattr(event, "prev_events"):
-            return
+        if not hasattr(event, "prev_events"):
+            event.prev_events = [
+                (event_id, hashes)
+                for event_id, hashes, _ in self.prev_events
+            ]
 
-        event.prev_events = [
-            (event_id, hashes)
-            for event_id, hashes, _ in self.prev_events
-        ]
+            if self.prev_events:
+                event.depth = max([int(v) for _, _, v in self.prev_events]) + 1
+            else:
+                event.depth = 0
 
-        if self.prev_events:
-            event.depth = max([int(v) for _, _, v in self.prev_events]) + 1
-        else:
-            event.depth = 0
+        if not hasattr(event, "prev_state") and self.prev_state is not None:
+            event.prev_state = self.prev_state
 
 
 def schema_path(schema):
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 7d445b4633..7821fc4726 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -245,7 +245,6 @@ class SQLBaseStore(object):
 
         return [r[0] for r in txn.fetchall()]
 
-
     def _simple_select_onecol(self, table, keyvalues, retcol):
         """Executes a SELECT query on the named table, which returns a list
         comprising of the values of the named column from the selected rows.
@@ -273,17 +272,30 @@ class SQLBaseStore(object):
             keyvalues : dict of column names and values to select the rows with
             retcols : list of strings giving the names of the columns to return
         """
+        return self.runInteraction(
+            "_simple_select_list",
+            self._simple_select_list_txn,
+            table, keyvalues, retcols
+        )
+
+    def _simple_select_list_txn(self, txn, table, keyvalues, retcols):
+        """Executes a SELECT query on the named table, which may return zero or
+        more rows, returning the result as a list of dicts.
+
+        Args:
+            txn : Transaction object
+            table : string giving the table name
+            keyvalues : dict of column names and values to select the rows with
+            retcols : list of strings giving the names of the columns to return
+        """
         sql = "SELECT %s FROM %s WHERE %s" % (
             ", ".join(retcols),
             table,
-            " AND ".join("%s = ?" % (k) for k in keyvalues)
+            " AND ".join("%s = ?" % (k, ) for k in keyvalues)
         )
 
-        def func(txn):
-            txn.execute(sql, keyvalues.values())
-            return self.cursor_to_dict(txn)
-
-        return self.runInteraction("_simple_select_list", func)
+        txn.execute(sql, keyvalues.values())
+        return self.cursor_to_dict(txn)
 
     def _simple_update_one(self, table, keyvalues, updatevalues,
                            retcols=None):
@@ -417,6 +429,10 @@ class SQLBaseStore(object):
         d.pop("topological_ordering", None)
         d.pop("processed", None)
         d["origin_server_ts"] = d.pop("ts", 0)
+        replaces_state = d.pop("prev_state", None)
+
+        if replaces_state:
+            d["replaces_state"] = replaces_state
 
         d.update(json.loads(row_dict["unrecognized_keys"]))
         d["content"] = json.loads(d["content"])
@@ -450,16 +466,32 @@ class SQLBaseStore(object):
                 k: encode_base64(v) for k, v in signatures.items()
             }
 
-            ev.prev_events = self._get_prev_events(txn, ev.event_id)
-
-            if hasattr(ev, "prev_state"):
-                # Load previous state_content.
-                # TODO: Should we be pulling this out above?
-                cursor = txn.execute(select_event_sql, (ev.prev_state,))
-                prevs = self.cursor_to_dict(cursor)
-                if prevs:
-                    prev = self._parse_event_from_row(prevs[0])
-                    ev.prev_content = prev.content
+            prevs = self._get_prev_events_and_state(txn, ev.event_id)
+
+            ev.prev_events = [
+                (e_id, h)
+                for e_id, h, is_state in prevs
+                if is_state == 0
+            ]
+
+            if hasattr(ev, "state_key"):
+                ev.prev_state = [
+                    (e_id, h)
+                    for e_id, h, is_state in prevs
+                    if is_state == 1
+                ]
+
+                if hasattr(ev, "replaces_state"):
+                    # Load previous state_content.
+                    # FIXME (erikj): Handle multiple prev_states.
+                    cursor = txn.execute(
+                        select_event_sql,
+                        (ev.replaces_state,)
+                    )
+                    prevs = self.cursor_to_dict(cursor)
+                    if prevs:
+                        prev = self._parse_event_from_row(prevs[0])
+                        ev.prev_content = prev.content
 
             if not hasattr(ev, "redacted"):
                 logger.debug("Doesn't have redacted key: %s", ev)
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index f427aba879..180a764134 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -69,19 +69,21 @@ class EventFederationStore(SQLBaseStore):
 
         return results
 
-    def _get_prev_events(self, txn, event_id):
-        prev_ids = self._simple_select_onecol_txn(
+    def _get_latest_state_in_room(self, txn, room_id, type, state_key):
+        event_ids = self._simple_select_onecol_txn(
             txn,
-            table="event_edges",
+            table="state_forward_extremities",
             keyvalues={
-                "event_id": event_id,
+                "room_id": room_id,
+                "type": type,
+                "state_key": state_key,
             },
-            retcol="prev_event_id",
+            retcol="event_id",
         )
 
         results = []
-        for prev_event_id in prev_ids:
-            hashes = self._get_event_reference_hashes_txn(txn, prev_event_id)
+        for event_id in event_ids:
+            hashes = self._get_event_reference_hashes_txn(txn, event_id)
             prev_hashes = {
                 k: encode_base64(v) for k, v in hashes.items()
                 if k == "sha256"
@@ -90,6 +92,53 @@ class EventFederationStore(SQLBaseStore):
 
         return results
 
+    def _get_prev_events(self, txn, event_id):
+        results = self._get_prev_events_and_state(
+            txn,
+            event_id,
+            is_state=0,
+        )
+
+        return [(e_id, h, ) for e_id, h, _ in results]
+
+    def _get_prev_state(self, txn, event_id):
+        results = self._get_prev_events_and_state(
+            txn,
+            event_id,
+            is_state=1,
+        )
+
+        return [(e_id, h, ) for e_id, h, _ in results]
+
+    def _get_prev_events_and_state(self, txn, event_id, is_state=None):
+        keyvalues = {
+            "event_id": event_id,
+        }
+
+        if is_state is not None:
+            keyvalues["is_state"] = is_state
+
+        res = self._simple_select_list_txn(
+            txn,
+            table="event_edges",
+            keyvalues=keyvalues,
+            retcols=["prev_event_id", "is_state"],
+        )
+
+        results = []
+        for d in res:
+            hashes = self._get_event_reference_hashes_txn(
+                txn,
+                d["prev_event_id"]
+            )
+            prev_hashes = {
+                k: encode_base64(v) for k, v in hashes.items()
+                if k == "sha256"
+            }
+            results.append((d["prev_event_id"], prev_hashes, d["is_state"]))
+
+        return results
+
     def get_min_depth(self, room_id):
         return self.runInteraction(
             "get_min_depth",
@@ -135,6 +184,7 @@ class EventFederationStore(SQLBaseStore):
                     "event_id": event_id,
                     "prev_event_id": e_id,
                     "room_id": room_id,
+                    "is_state": 0,
                 },
                 or_ignore=True,
             )
diff --git a/synapse/storage/schema/event_edges.sql b/synapse/storage/schema/event_edges.sql
index e5f768c705..51695826a8 100644
--- a/synapse/storage/schema/event_edges.sql
+++ b/synapse/storage/schema/event_edges.sql
@@ -1,7 +1,7 @@
 
 CREATE TABLE IF NOT EXISTS event_forward_extremities(
-    event_id TEXT,
-    room_id TEXT,
+    event_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
     CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE
 );
 
@@ -10,8 +10,8 @@ CREATE INDEX IF NOT EXISTS ev_extrem_id ON event_forward_extremities(event_id);
 
 
 CREATE TABLE IF NOT EXISTS event_backward_extremities(
-    event_id TEXT,
-    room_id TEXT,
+    event_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
     CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE
 );
 
@@ -20,10 +20,11 @@ CREATE INDEX IF NOT EXISTS ev_b_extrem_id ON event_backward_extremities(event_id
 
 
 CREATE TABLE IF NOT EXISTS event_edges(
-    event_id TEXT,
-    prev_event_id TEXT,
-    room_id TEXT,
-    CONSTRAINT uniqueness UNIQUE (event_id, prev_event_id, room_id)
+    event_id TEXT NOT NULL,
+    prev_event_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
+    is_state INTEGER NOT NULL,
+    CONSTRAINT uniqueness UNIQUE (event_id, prev_event_id, room_id, is_state)
 );
 
 CREATE INDEX IF NOT EXISTS ev_edges_id ON event_edges(event_id);
@@ -31,8 +32,8 @@ CREATE INDEX IF NOT EXISTS ev_edges_prev_id ON event_edges(prev_event_id);
 
 
 CREATE TABLE IF NOT EXISTS room_depth(
-    room_id TEXT,
-    min_depth INTEGER,
+    room_id TEXT NOT NULL,
+    min_depth INTEGER NOT NULL,
     CONSTRAINT uniqueness UNIQUE (room_id)
 );
 
@@ -40,10 +41,25 @@ CREATE INDEX IF NOT EXISTS room_depth_room ON room_depth(room_id);
 
 
 create TABLE IF NOT EXISTS event_destinations(
-    event_id TEXT,
-    destination TEXT,
+    event_id TEXT NOT NULL,
+    destination TEXT NOT NULL,
     delivered_ts INTEGER DEFAULT 0, -- or 0 if not delivered
     CONSTRAINT uniqueness UNIQUE (event_id, destination) ON CONFLICT REPLACE
 );
 
 CREATE INDEX IF NOT EXISTS event_destinations_id ON event_destinations(event_id);
+
+
+CREATE TABLE IF NOT EXISTS state_forward_extremities(
+    event_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
+    type TEXT NOT NULL,
+    state_key TEXT NOT NULL,
+    CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE
+);
+
+CREATE INDEX IF NOT EXISTS st_extrem_keys ON state_forward_extremities(
+    room_id, type, state_key
+);
+CREATE INDEX IF NOT EXISTS st_extrem_id ON state_forward_extremities(event_id);
+