summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/api/events/__init__.py1
-rw-r--r--synapse/handlers/directory.py5
-rw-r--r--synapse/handlers/federation.py4
-rw-r--r--synapse/handlers/message.py11
-rw-r--r--synapse/handlers/profile.py6
-rw-r--r--synapse/handlers/room.py16
-rw-r--r--synapse/rest/room.py2
-rw-r--r--synapse/state.py39
-rw-r--r--synapse/storage/__init__.py92
-rw-r--r--synapse/storage/_base.py66
-rw-r--r--synapse/storage/event_federation.py64
-rw-r--r--synapse/storage/schema/event_edges.sql40
-rw-r--r--synapse/util/jsonobject.py2
13 files changed, 220 insertions, 128 deletions
diff --git a/synapse/api/events/__init__.py b/synapse/api/events/__init__.py
index 168b812311..fc3f350570 100644
--- a/synapse/api/events/__init__.py
+++ b/synapse/api/events/__init__.py
@@ -60,6 +60,7 @@ class SynapseEvent(JsonEncodedObject):
         "age_ts",
         "prev_content",
         "prev_state",
+        "replaces_state",
         "redacted_because",
         "origin_server_ts",
     ]
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 6e897e915d..164363cdc5 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -147,10 +147,7 @@ class DirectoryHandler(BaseHandler):
             content={"aliases": aliases},
         )
 
-        snapshot = yield self.store.snapshot_room(
-            room_id=room_id,
-            user_id=user_id,
-        )
+        snapshot = yield self.store.snapshot_room(event)
 
         yield self._on_new_room_event(
             event, snapshot, extra_users=[user_id], suppress_auth=True
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 1464a60937..513ec9a5e3 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -313,9 +313,7 @@ class FederationHandler(BaseHandler):
             state_key=user_id,
         )
 
-        snapshot = yield self.store.snapshot_room(
-            event.room_id, event.user_id,
-        )
+        snapshot = yield self.store.snapshot_room(event)
         snapshot.fill_out_prev_events(event)
 
         yield self.state_handler.annotate_state_groups(event)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index c6f6ab14d1..8394013df3 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -81,7 +81,7 @@ class MessageHandler(BaseHandler):
         user = self.hs.parse_userid(event.user_id)
         assert user.is_mine, "User must be our own: %s" % (user,)
 
-        snapshot = yield self.store.snapshot_room(event.room_id, event.user_id)
+        snapshot = yield self.store.snapshot_room(event)
 
         yield self._on_new_room_event(
             event, snapshot, suppress_auth=suppress_auth
@@ -141,12 +141,7 @@ class MessageHandler(BaseHandler):
             SynapseError if something went wrong.
         """
 
-        snapshot = yield self.store.snapshot_room(
-            event.room_id,
-            event.user_id,
-            state_type=event.type,
-            state_key=event.state_key,
-        )
+        snapshot = yield self.store.snapshot_room(event)
 
         yield self._on_new_room_event(event, snapshot)
 
@@ -214,7 +209,7 @@ class MessageHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def send_feedback(self, event):
-        snapshot = yield self.store.snapshot_room(event.room_id, event.user_id)
+        snapshot = yield self.store.snapshot_room(event)
 
         # store message in db
         yield self._on_new_room_event(event, snapshot)
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 4cd0a06093..e47814483a 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -17,7 +17,6 @@ from twisted.internet import defer
 
 from synapse.api.errors import SynapseError, AuthError, CodeMessageException
 from synapse.api.constants import Membership
-from synapse.api.events.room import RoomMemberEvent
 
 from ._base import BaseHandler
 
@@ -196,10 +195,7 @@ class ProfileHandler(BaseHandler):
         )
 
         for j in joins:
-            snapshot = yield self.store.snapshot_room(
-                j.room_id, j.state_key, RoomMemberEvent.TYPE,
-                j.state_key
-            )
+            snapshot = yield self.store.snapshot_room(j)
 
             content = {
                 "membership": j.content["membership"],
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index f176ad39bf..55c893eb58 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -122,10 +122,7 @@ class RoomCreationHandler(BaseHandler):
 
         @defer.inlineCallbacks
         def handle_event(event):
-            snapshot = yield self.store.snapshot_room(
-                room_id=room_id,
-                user_id=user_id,
-            )
+            snapshot = yield self.store.snapshot_room(event)
 
             logger.debug("Event: %s", event)
 
@@ -364,10 +361,8 @@ class RoomMemberHandler(BaseHandler):
         """
         target_user_id = event.state_key
 
-        snapshot = yield self.store.snapshot_room(
-            event.room_id, event.user_id,
-            RoomMemberEvent.TYPE, target_user_id
-        )
+        snapshot = yield self.store.snapshot_room(event)
+
         ## TODO(markjh): get prev state from snapshot.
         prev_state = yield self.store.get_room_member(
             target_user_id, event.room_id
@@ -442,10 +437,7 @@ class RoomMemberHandler(BaseHandler):
             content=content,
         )
 
-        snapshot = yield self.store.snapshot_room(
-            room_id, joinee.to_string(), RoomMemberEvent.TYPE,
-            joinee.to_string()
-        )
+        snapshot = yield self.store.snapshot_room(new_event)
 
         yield self._do_join(new_event, snapshot, room_host=host, do_auth=True)
 
diff --git a/synapse/rest/room.py b/synapse/rest/room.py
index ec0ce78fda..997895dab0 100644
--- a/synapse/rest/room.py
+++ b/synapse/rest/room.py
@@ -138,7 +138,7 @@ class RoomStateEventRestServlet(RestServlet):
             raise SynapseError(
                 404, "Event not found.", errcode=Codes.NOT_FOUND
             )
-        defer.returnValue((200, data[0].get_dict()["content"]))
+        defer.returnValue((200, data.get_dict()["content"]))
 
     @defer.inlineCallbacks
     def on_PUT(self, request, room_id, event_type, state_key):
diff --git a/synapse/state.py b/synapse/state.py
index 32744e047c..97a8160a33 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -47,40 +47,6 @@ class StateHandler(object):
 
     @defer.inlineCallbacks
     @log_function
-    def handle_new_event(self, event, snapshot):
-        """ Given an event this works out if a) we have sufficient power level
-        to update the state and b) works out what the prev_state should be.
-
-        Returns:
-            Deferred: Resolved with a boolean indicating if we successfully
-            updated the state.
-
-        Raised:
-            AuthError
-        """
-        # This needs to be done in a transaction.
-
-        if not hasattr(event, "state_key"):
-            return
-
-        # Now I need to fill out the prev state and work out if it has auth
-        # (w.r.t. to power levels)
-
-        snapshot.fill_out_prev_events(event)
-        yield self.annotate_state_groups(event)
-
-        if event.old_state_events:
-            current_state = event.old_state_events.get(
-                (event.type, event.state_key)
-            )
-
-            if current_state:
-                event.prev_state = current_state.event_id
-
-        defer.returnValue(True)
-
-    @defer.inlineCallbacks
-    @log_function
     def annotate_state_groups(self, event, old_state=None):
         yield run_on_reactor()
 
@@ -111,7 +77,10 @@ class StateHandler(object):
         event.old_state_events = copy.deepcopy(new_state)
 
         if hasattr(event, "state_key"):
-            new_state[(event.type, event.state_key)] = event
+            key = (event.type, event.state_key)
+            if key in new_state:
+                event.replaces_state = new_state[key].event_id
+            new_state[key] = event
 
         event.state_group = None
         event.state_events = new_state
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 6b8fed4502..2d62fc2ed0 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -242,8 +242,8 @@ class DataStore(RoomMemberStore, RoomStore,
                 "state_key": event.state_key,
             }
 
-            if hasattr(event, "prev_state"):
-                vals["prev_state"] = event.prev_state
+            if hasattr(event, "replaces_state"):
+                vals["prev_state"] = event.replaces_state
 
             self._simple_insert_txn(txn, "state_events", vals)
 
@@ -258,6 +258,40 @@ class DataStore(RoomMemberStore, RoomStore,
                 }
             )
 
+            for e_id, h in event.prev_state:
+                self._simple_insert_txn(
+                    txn,
+                    table="event_edges",
+                    values={
+                        "event_id": event.event_id,
+                        "prev_event_id": e_id,
+                        "room_id": event.room_id,
+                        "is_state": 1,
+                    },
+                    or_ignore=True,
+                )
+
+            if not backfilled:
+                self._simple_insert_txn(
+                    txn,
+                    table="state_forward_extremities",
+                    values={
+                        "event_id": event.event_id,
+                        "room_id": event.room_id,
+                        "type": event.type,
+                        "state_key": event.state_key,
+                    }
+                )
+
+                for prev_state_id, _ in event.prev_state:
+                    self._simple_delete_txn(
+                        txn,
+                        table="state_forward_extremities",
+                        keyvalues={
+                            "event_id": prev_state_id,
+                        }
+                    )
+
         for hash_alg, hash_base64 in event.hashes.items():
             hash_bytes = decode_base64(hash_base64)
             self._store_event_content_hash_txn(
@@ -357,7 +391,7 @@ class DataStore(RoomMemberStore, RoomStore,
             ],
         )
 
-    def snapshot_room(self, room_id, user_id, state_type=None, state_key=None):
+    def snapshot_room(self, event):
         """Snapshot the room for an update by a user
         Args:
             room_id (synapse.types.RoomId): The room to snapshot.
@@ -368,16 +402,29 @@ class DataStore(RoomMemberStore, RoomStore,
             synapse.storage.Snapshot: A snapshot of the state of the room.
         """
         def _snapshot(txn):
-            membership_state = self._get_room_member(txn, user_id, room_id)
-            prev_events = self._get_latest_events_in_room(txn, room_id)
+            prev_events = self._get_latest_events_in_room(
+                txn,
+                event.room_id
+            )
+
+            prev_state = None
+            state_key = None
+            if hasattr(event, "state_key"):
+                state_key = event.state_key
+                prev_state = self._get_latest_state_in_room(
+                    txn,
+                    event.room_id,
+                    type=event.type,
+                    state_key=state_key,
+                )
 
             return Snapshot(
                 store=self,
-                room_id=room_id,
-                user_id=user_id,
+                room_id=event.room_id,
+                user_id=event.user_id,
                 prev_events=prev_events,
-                membership_state=membership_state,
-                state_type=state_type,
+                prev_state=prev_state,
+                state_type=event.type,
                 state_key=state_key,
             )
 
@@ -400,30 +447,29 @@ class Snapshot(object):
     """
 
     def __init__(self, store, room_id, user_id, prev_events,
-                 membership_state, state_type=None, state_key=None,
-                 prev_state_pdu=None):
+                 prev_state, state_type=None, state_key=None):
         self.store = store
         self.room_id = room_id
         self.user_id = user_id
         self.prev_events = prev_events
-        self.membership_state = membership_state
+        self.prev_state = prev_state
         self.state_type = state_type
         self.state_key = state_key
-        self.prev_state_pdu = prev_state_pdu
 
     def fill_out_prev_events(self, event):
-        if hasattr(event, "prev_events"):
-            return
+        if not hasattr(event, "prev_events"):
+            event.prev_events = [
+                (event_id, hashes)
+                for event_id, hashes, _ in self.prev_events
+            ]
 
-        event.prev_events = [
-            (event_id, hashes)
-            for event_id, hashes, _ in self.prev_events
-        ]
+            if self.prev_events:
+                event.depth = max([int(v) for _, _, v in self.prev_events]) + 1
+            else:
+                event.depth = 0
 
-        if self.prev_events:
-            event.depth = max([int(v) for _, _, v in self.prev_events]) + 1
-        else:
-            event.depth = 0
+        if not hasattr(event, "prev_state") and self.prev_state is not None:
+            event.prev_state = self.prev_state
 
 
 def schema_path(schema):
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 7d445b4633..7821fc4726 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -245,7 +245,6 @@ class SQLBaseStore(object):
 
         return [r[0] for r in txn.fetchall()]
 
-
     def _simple_select_onecol(self, table, keyvalues, retcol):
         """Executes a SELECT query on the named table, which returns a list
         comprising of the values of the named column from the selected rows.
@@ -273,17 +272,30 @@ class SQLBaseStore(object):
             keyvalues : dict of column names and values to select the rows with
             retcols : list of strings giving the names of the columns to return
         """
+        return self.runInteraction(
+            "_simple_select_list",
+            self._simple_select_list_txn,
+            table, keyvalues, retcols
+        )
+
+    def _simple_select_list_txn(self, txn, table, keyvalues, retcols):
+        """Executes a SELECT query on the named table, which may return zero or
+        more rows, returning the result as a list of dicts.
+
+        Args:
+            txn : Transaction object
+            table : string giving the table name
+            keyvalues : dict of column names and values to select the rows with
+            retcols : list of strings giving the names of the columns to return
+        """
         sql = "SELECT %s FROM %s WHERE %s" % (
             ", ".join(retcols),
             table,
-            " AND ".join("%s = ?" % (k) for k in keyvalues)
+            " AND ".join("%s = ?" % (k, ) for k in keyvalues)
         )
 
-        def func(txn):
-            txn.execute(sql, keyvalues.values())
-            return self.cursor_to_dict(txn)
-
-        return self.runInteraction("_simple_select_list", func)
+        txn.execute(sql, keyvalues.values())
+        return self.cursor_to_dict(txn)
 
     def _simple_update_one(self, table, keyvalues, updatevalues,
                            retcols=None):
@@ -417,6 +429,10 @@ class SQLBaseStore(object):
         d.pop("topological_ordering", None)
         d.pop("processed", None)
         d["origin_server_ts"] = d.pop("ts", 0)
+        replaces_state = d.pop("prev_state", None)
+
+        if replaces_state:
+            d["replaces_state"] = replaces_state
 
         d.update(json.loads(row_dict["unrecognized_keys"]))
         d["content"] = json.loads(d["content"])
@@ -450,16 +466,32 @@ class SQLBaseStore(object):
                 k: encode_base64(v) for k, v in signatures.items()
             }
 
-            ev.prev_events = self._get_prev_events(txn, ev.event_id)
-
-            if hasattr(ev, "prev_state"):
-                # Load previous state_content.
-                # TODO: Should we be pulling this out above?
-                cursor = txn.execute(select_event_sql, (ev.prev_state,))
-                prevs = self.cursor_to_dict(cursor)
-                if prevs:
-                    prev = self._parse_event_from_row(prevs[0])
-                    ev.prev_content = prev.content
+            prevs = self._get_prev_events_and_state(txn, ev.event_id)
+
+            ev.prev_events = [
+                (e_id, h)
+                for e_id, h, is_state in prevs
+                if is_state == 0
+            ]
+
+            if hasattr(ev, "state_key"):
+                ev.prev_state = [
+                    (e_id, h)
+                    for e_id, h, is_state in prevs
+                    if is_state == 1
+                ]
+
+                if hasattr(ev, "replaces_state"):
+                    # Load previous state_content.
+                    # FIXME (erikj): Handle multiple prev_states.
+                    cursor = txn.execute(
+                        select_event_sql,
+                        (ev.replaces_state,)
+                    )
+                    prevs = self.cursor_to_dict(cursor)
+                    if prevs:
+                        prev = self._parse_event_from_row(prevs[0])
+                        ev.prev_content = prev.content
 
             if not hasattr(ev, "redacted"):
                 logger.debug("Doesn't have redacted key: %s", ev)
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index f427aba879..180a764134 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -69,19 +69,21 @@ class EventFederationStore(SQLBaseStore):
 
         return results
 
-    def _get_prev_events(self, txn, event_id):
-        prev_ids = self._simple_select_onecol_txn(
+    def _get_latest_state_in_room(self, txn, room_id, type, state_key):
+        event_ids = self._simple_select_onecol_txn(
             txn,
-            table="event_edges",
+            table="state_forward_extremities",
             keyvalues={
-                "event_id": event_id,
+                "room_id": room_id,
+                "type": type,
+                "state_key": state_key,
             },
-            retcol="prev_event_id",
+            retcol="event_id",
         )
 
         results = []
-        for prev_event_id in prev_ids:
-            hashes = self._get_event_reference_hashes_txn(txn, prev_event_id)
+        for event_id in event_ids:
+            hashes = self._get_event_reference_hashes_txn(txn, event_id)
             prev_hashes = {
                 k: encode_base64(v) for k, v in hashes.items()
                 if k == "sha256"
@@ -90,6 +92,53 @@ class EventFederationStore(SQLBaseStore):
 
         return results
 
+    def _get_prev_events(self, txn, event_id):
+        results = self._get_prev_events_and_state(
+            txn,
+            event_id,
+            is_state=0,
+        )
+
+        return [(e_id, h, ) for e_id, h, _ in results]
+
+    def _get_prev_state(self, txn, event_id):
+        results = self._get_prev_events_and_state(
+            txn,
+            event_id,
+            is_state=1,
+        )
+
+        return [(e_id, h, ) for e_id, h, _ in results]
+
+    def _get_prev_events_and_state(self, txn, event_id, is_state=None):
+        keyvalues = {
+            "event_id": event_id,
+        }
+
+        if is_state is not None:
+            keyvalues["is_state"] = is_state
+
+        res = self._simple_select_list_txn(
+            txn,
+            table="event_edges",
+            keyvalues=keyvalues,
+            retcols=["prev_event_id", "is_state"],
+        )
+
+        results = []
+        for d in res:
+            hashes = self._get_event_reference_hashes_txn(
+                txn,
+                d["prev_event_id"]
+            )
+            prev_hashes = {
+                k: encode_base64(v) for k, v in hashes.items()
+                if k == "sha256"
+            }
+            results.append((d["prev_event_id"], prev_hashes, d["is_state"]))
+
+        return results
+
     def get_min_depth(self, room_id):
         return self.runInteraction(
             "get_min_depth",
@@ -135,6 +184,7 @@ class EventFederationStore(SQLBaseStore):
                     "event_id": event_id,
                     "prev_event_id": e_id,
                     "room_id": room_id,
+                    "is_state": 0,
                 },
                 or_ignore=True,
             )
diff --git a/synapse/storage/schema/event_edges.sql b/synapse/storage/schema/event_edges.sql
index e5f768c705..51695826a8 100644
--- a/synapse/storage/schema/event_edges.sql
+++ b/synapse/storage/schema/event_edges.sql
@@ -1,7 +1,7 @@
 
 CREATE TABLE IF NOT EXISTS event_forward_extremities(
-    event_id TEXT,
-    room_id TEXT,
+    event_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
     CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE
 );
 
@@ -10,8 +10,8 @@ CREATE INDEX IF NOT EXISTS ev_extrem_id ON event_forward_extremities(event_id);
 
 
 CREATE TABLE IF NOT EXISTS event_backward_extremities(
-    event_id TEXT,
-    room_id TEXT,
+    event_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
     CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE
 );
 
@@ -20,10 +20,11 @@ CREATE INDEX IF NOT EXISTS ev_b_extrem_id ON event_backward_extremities(event_id
 
 
 CREATE TABLE IF NOT EXISTS event_edges(
-    event_id TEXT,
-    prev_event_id TEXT,
-    room_id TEXT,
-    CONSTRAINT uniqueness UNIQUE (event_id, prev_event_id, room_id)
+    event_id TEXT NOT NULL,
+    prev_event_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
+    is_state INTEGER NOT NULL,
+    CONSTRAINT uniqueness UNIQUE (event_id, prev_event_id, room_id, is_state)
 );
 
 CREATE INDEX IF NOT EXISTS ev_edges_id ON event_edges(event_id);
@@ -31,8 +32,8 @@ CREATE INDEX IF NOT EXISTS ev_edges_prev_id ON event_edges(prev_event_id);
 
 
 CREATE TABLE IF NOT EXISTS room_depth(
-    room_id TEXT,
-    min_depth INTEGER,
+    room_id TEXT NOT NULL,
+    min_depth INTEGER NOT NULL,
     CONSTRAINT uniqueness UNIQUE (room_id)
 );
 
@@ -40,10 +41,25 @@ CREATE INDEX IF NOT EXISTS room_depth_room ON room_depth(room_id);
 
 
 create TABLE IF NOT EXISTS event_destinations(
-    event_id TEXT,
-    destination TEXT,
+    event_id TEXT NOT NULL,
+    destination TEXT NOT NULL,
     delivered_ts INTEGER DEFAULT 0, -- or 0 if not delivered
     CONSTRAINT uniqueness UNIQUE (event_id, destination) ON CONFLICT REPLACE
 );
 
 CREATE INDEX IF NOT EXISTS event_destinations_id ON event_destinations(event_id);
+
+
+CREATE TABLE IF NOT EXISTS state_forward_extremities(
+    event_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
+    type TEXT NOT NULL,
+    state_key TEXT NOT NULL,
+    CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE
+);
+
+CREATE INDEX IF NOT EXISTS st_extrem_keys ON state_forward_extremities(
+    room_id, type, state_key
+);
+CREATE INDEX IF NOT EXISTS st_extrem_id ON state_forward_extremities(event_id);
+
diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py
index c91eb897a8..e79b68f661 100644
--- a/synapse/util/jsonobject.py
+++ b/synapse/util/jsonobject.py
@@ -80,7 +80,7 @@ class JsonEncodedObject(object):
 
     def get_full_dict(self):
         d = {
-            k: v for (k, v) in self.__dict__.items()
+            k: _encode(v) for (k, v) in self.__dict__.items()
             if k in self.valid_keys or k in self.internal_keys
         }
         d.update(self.unrecognized_keys)