summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r-- synapse/storage/__init__.py 77
-rw-r--r-- synapse/storage/_base.py 155
-rw-r--r-- synapse/storage/event_federation.py 11
-rw-r--r-- synapse/storage/schema/im.sql 12
-rw-r--r-- synapse/storage/schema/state.sql 3
-rw-r--r-- synapse/storage/signatures.py 19
-rw-r--r-- synapse/storage/state.py 13
7 files changed, 141 insertions, 149 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index c9ab434b4e..e75eaa92d5 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -39,6 +39,7 @@ from .state import StateStore
 from .signatures import SignatureStore
 
 from syutil.base64util import decode_base64
+from syutil.jsonutil import encode_canonical_json
 
 from synapse.crypto.event_signing import compute_event_reference_hash
 
@@ -89,7 +90,6 @@ class DataStore(RoomMemberStore, RoomStore,
 
     def __init__(self, hs):
         super(DataStore, self).__init__(hs)
-        self.event_factory = hs.get_event_factory()
         self.hs = hs
 
         self.min_token_deferred = self._get_min_token()
@@ -97,8 +97,8 @@ class DataStore(RoomMemberStore, RoomStore,
 
     @defer.inlineCallbacks
     @log_function
-    def persist_event(self, event, backfilled=False, is_new_state=True,
-                      current_state=None):
+    def persist_event(self, event, context, backfilled=False,
+                      is_new_state=True, current_state=None):
         stream_ordering = None
         if backfilled:
             if not self.min_token_deferred.called:
@@ -111,6 +111,7 @@ class DataStore(RoomMemberStore, RoomStore,
                 "persist_event",
                 self._persist_event_txn,
                 event=event,
+                context=context,
                 backfilled=backfilled,
                 stream_ordering=stream_ordering,
                 is_new_state=is_new_state,
@@ -121,29 +122,20 @@ class DataStore(RoomMemberStore, RoomStore,
 
     @defer.inlineCallbacks
     def get_event(self, event_id, allow_none=False):
-        events_dict = yield self._simple_select_one(
-            "events",
-            {"event_id": event_id},
-            [
-                "event_id",
-                "type",
-                "room_id",
-                "content",
-                "unrecognized_keys",
-                "depth",
-            ],
-            allow_none=allow_none,
-        )
+        events = yield self._get_events([event_id])
 
-        if not events_dict:
-            defer.returnValue(None)
+        if not events:
+            if allow_none:
+                defer.returnValue(None)
+            else:
+                raise RuntimeError("Could not find event %s" % (event_id,))
 
-        event = yield self._parse_events([events_dict])
-        defer.returnValue(event[0])
+        defer.returnValue(events[0])
 
     @log_function
-    def _persist_event_txn(self, txn, event, backfilled, stream_ordering=None,
-                           is_new_state=True, current_state=None):
+    def _persist_event_txn(self, txn, event, context, backfilled,
+                           stream_ordering=None, is_new_state=True,
+                           current_state=None):
         if event.type == RoomMemberEvent.TYPE:
             self._store_room_member_txn(txn, event)
         elif event.type == FeedbackEvent.TYPE:
@@ -156,15 +148,35 @@ class DataStore(RoomMemberStore, RoomStore,
             self._store_redaction(txn, event)
 
         outlier = False
-        if hasattr(event, "outlier"):
-            outlier = event.outlier
+        if hasattr(event.internal_metadata, "outlier"):
+            outlier = event.internal_metadata.outlier
+
+        event_dict = {
+            k: v
+            for k, v in event.get_dict().items()
+            if k not in [
+                "redacted",
+                "redacted_because",
+            ]
+        }
+
+        self._simple_insert_txn(
+            txn,
+            table="event_json",
+            values={
+                "event_id": event.event_id,
+                "room_id": event.room_id,
+                "json": encode_canonical_json(event_dict).decode("UTF-8"),
+            },
+            or_replace=True,
+        )
 
         vals = {
             "topological_ordering": event.depth,
             "event_id": event.event_id,
             "type": event.type,
             "room_id": event.room_id,
-            "content": json.dumps(event.content),
+            "content": json.dumps(event.get_dict()["content"]),
             "processed": True,
             "outlier": outlier,
             "depth": event.depth,
@@ -175,7 +187,7 @@ class DataStore(RoomMemberStore, RoomStore,
 
         unrec = {
             k: v
-            for k, v in event.get_full_dict().items()
+            for k, v in event.get_dict().items()
             if k not in vals.keys() and k not in [
                 "redacted",
                 "redacted_because",
@@ -210,7 +222,8 @@ class DataStore(RoomMemberStore, RoomStore,
             room_id=event.room_id,
         )
 
-        self._store_state_groups_txn(txn, event)
+        if not outlier:
+            self._store_state_groups_txn(txn, event, context)
 
         if current_state:
             txn.execute(
@@ -304,16 +317,6 @@ class DataStore(RoomMemberStore, RoomStore,
                 txn, event.event_id, hash_alg, hash_bytes,
             )
 
-        if hasattr(event, "signatures"):
-            logger.debug("sigs: %s", event.signatures)
-            for name, sigs in event.signatures.items():
-                for key_id, signature_base64 in sigs.items():
-                    signature_bytes = decode_base64(signature_base64)
-                    self._store_event_signature_txn(
-                        txn, event.event_id, name, key_id,
-                        signature_bytes,
-                    )
-
         for prev_event_id, prev_hashes in event.prev_events:
             for alg, hash_base64 in prev_hashes.items():
                 hash_bytes = decode_base64(hash_base64)
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index e72200e2f7..31d5163c19 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -15,15 +15,14 @@
 import logging
 
 from synapse.api.errors import StoreError
-from synapse.api.events.utils import prune_event
+from synapse.events import FrozenEvent
+from synapse.events.utils import prune_event
 from synapse.util.logutils import log_function
 from synapse.util.logcontext import PreserveLoggingContext, LoggingContext
-from syutil.base64util import encode_base64
 
 from twisted.internet import defer
 
 import collections
-import copy
 import json
 import sys
 import time
@@ -84,7 +83,6 @@ class SQLBaseStore(object):
     def __init__(self, hs):
         self.hs = hs
         self._db_pool = hs.get_db_pool()
-        self.event_factory = hs.get_event_factory()
         self._clock = hs.get_clock()
 
     @defer.inlineCallbacks
@@ -436,123 +434,76 @@ class SQLBaseStore(object):
 
         return self.runInteraction("_simple_max_id", func)
 
-    def _parse_event_from_row(self, row_dict):
-        d = copy.deepcopy({k: v for k, v in row_dict.items()})
-
-        d.pop("stream_ordering", None)
-        d.pop("topological_ordering", None)
-        d.pop("processed", None)
-        d["origin_server_ts"] = d.pop("ts", 0)
-        replaces_state = d.pop("prev_state", None)
+    def _get_events(self, event_ids):
+        return self.runInteraction(
+            "_get_events", self._get_events_txn, event_ids
+        )
 
-        if replaces_state:
-            d["replaces_state"] = replaces_state
+    def _get_events_txn(self, txn, event_ids):
+        events = []
+        for e_id in event_ids:
+            ev = self._get_event_txn(txn, e_id)
 
-        d.update(json.loads(row_dict["unrecognized_keys"]))
-        d["content"] = json.loads(d["content"])
-        del d["unrecognized_keys"]
+            if ev:
+                events.append(ev)
 
-        if "age_ts" not in d:
-            # For compatibility
-            d["age_ts"] = d.get("origin_server_ts", 0)
+        return events
 
-        return self.event_factory.create_event(
-            etype=d["type"],
-            **d
+    def _get_event_txn(self, txn, event_id, check_redacted=True,
+                       get_prev_content=True):
+        sql = (
+            "SELECT json, r.event_id FROM event_json as e "
+            "LEFT JOIN redactions as r ON e.event_id = r.redacts "
+            "WHERE e.event_id = ? "
+            "LIMIT 1 "
         )
 
-    def _get_events_txn(self, txn, event_ids):
-        # FIXME (erikj): This should be batched?
+        txn.execute(sql, (event_id,))
 
-        sql = "SELECT * FROM events WHERE event_id = ? ORDER BY rowid asc"
+        res = txn.fetchone()
 
-        event_rows = []
-        for e_id in event_ids:
-            c = txn.execute(sql, (e_id,))
-            event_rows.extend(self.cursor_to_dict(c))
+        if not res:
+            return None
 
-        return self._parse_events_txn(txn, event_rows)
+        js, redacted = res
 
-    def _parse_events(self, rows):
-        return self.runInteraction(
-            "_parse_events", self._parse_events_txn, rows
-        )
+        d = json.loads(js)
 
-    def _parse_events_txn(self, txn, rows):
-        events = [self._parse_event_from_row(r) for r in rows]
+        ev = FrozenEvent(d)
 
-        select_event_sql = (
-            "SELECT * FROM events WHERE event_id = ? ORDER BY rowid asc"
-        )
+        if check_redacted and redacted:
+            ev = prune_event(ev)
 
-        for i, ev in enumerate(events):
-            signatures = self._get_event_signatures_txn(
-                txn, ev.event_id,
+            ev.unsigned["redacted_by"] = redacted
+            # Get the redaction event.
+
+            because = self._get_event_txn(
+                txn,
+                redacted,
+                check_redacted=False
             )
 
-            ev.signatures = {
-                n: {
-                    k: encode_base64(v) for k, v in s.items()
-                }
-                for n, s in signatures.items()
-            }
+            if because:
+                ev.unsigned["redacted_because"] = because
 
-            hashes = self._get_event_content_hashes_txn(
-                txn, ev.event_id,
-            )
+        if get_prev_content and "replaces_state" in ev.unsigned:
+            ev.unsigned["prev_content"] = self._get_event_txn(
+                txn,
+                ev.unsigned["replaces_state"],
+                get_prev_content=False,
+            ).get_dict()["content"]
 
-            ev.hashes = {
-                k: encode_base64(v) for k, v in hashes.items()
-            }
-
-            prevs = self._get_prev_events_and_state(txn, ev.event_id)
-
-            ev.prev_events = [
-                (e_id, h)
-                for e_id, h, is_state in prevs
-                if is_state == 0
-            ]
-
-            ev.auth_events = self._get_auth_events(txn, ev.event_id)
-
-            if hasattr(ev, "state_key"):
-                ev.prev_state = [
-                    (e_id, h)
-                    for e_id, h, is_state in prevs
-                    if is_state == 1
-                ]
-
-                if hasattr(ev, "replaces_state"):
-                    # Load previous state_content.
-                    # FIXME (erikj): Handle multiple prev_states.
-                    cursor = txn.execute(
-                        select_event_sql,
-                        (ev.replaces_state,)
-                    )
-                    prevs = self.cursor_to_dict(cursor)
-                    if prevs:
-                        prev = self._parse_event_from_row(prevs[0])
-                        ev.prev_content = prev.content
-
-            if not hasattr(ev, "redacted"):
-                logger.debug("Doesn't have redacted key: %s", ev)
-                ev.redacted = self._has_been_redacted_txn(txn, ev)
-
-            if ev.redacted:
-                # Get the redaction event.
-                select_event_sql = "SELECT * FROM events WHERE event_id = ?"
-                txn.execute(select_event_sql, (ev.redacted,))
-
-                del_evs = self._parse_events_txn(
-                    txn, self.cursor_to_dict(txn)
-                )
+        return ev
 
-                if del_evs:
-                    ev = prune_event(ev)
-                    events[i] = ev
-                    ev.redacted_because = del_evs[0]
+    def _parse_events(self, rows):
+        return self.runInteraction(
+            "_parse_events", self._parse_events_txn, rows
+        )
 
-        return events
+    def _parse_events_txn(self, txn, rows):
+        event_ids = [r["event_id"] for r in rows]
+
+        return self._get_events_txn(txn, event_ids)
 
     def _has_been_redacted_txn(self, txn, event):
         sql = "SELECT event_id FROM redactions WHERE redacts = ?"
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 6c559f8f63..ced066f407 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -177,14 +177,15 @@ class EventFederationStore(SQLBaseStore):
             retcols=["prev_event_id", "is_state"],
         )
 
+        hashes = self._get_prev_event_hashes_txn(txn, event_id)
+
         results = []
         for d in res:
-            hashes = self._get_event_reference_hashes_txn(
-                txn,
-                d["prev_event_id"]
-            )
+            edge_hash = self._get_event_reference_hashes_txn(txn, d["prev_event_id"])
+            edge_hash.update(hashes.get(d["prev_event_id"], {}))
             prev_hashes = {
-                k: encode_base64(v) for k, v in hashes.items()
+                k: encode_base64(v)
+                for k, v in edge_hash.items()
                 if k == "sha256"
             }
             results.append((d["prev_event_id"], prev_hashes, d["is_state"]))
diff --git a/synapse/storage/schema/im.sql b/synapse/storage/schema/im.sql
index 8ba732a23b..0300bb29e1 100644
--- a/synapse/storage/schema/im.sql
+++ b/synapse/storage/schema/im.sql
@@ -32,6 +32,18 @@ CREATE INDEX IF NOT EXISTS events_stream_ordering ON events (stream_ordering);
 CREATE INDEX IF NOT EXISTS events_topological_ordering ON events (topological_ordering);
 CREATE INDEX IF NOT EXISTS events_room_id ON events (room_id);
 
+
+CREATE TABLE IF NOT EXISTS event_json(
+    event_id TEXT NOT NULL,
+    room_id TEXT NOT NULL,
+    json BLOB NOT NULL,
+    CONSTRAINT ev_j_uniq UNIQUE (event_id)
+);
+
+CREATE INDEX IF NOT EXISTS event_json_id ON event_json(event_id);
+CREATE INDEX IF NOT EXISTS event_json_room_id ON event_json(room_id);
+
+
 CREATE TABLE IF NOT EXISTS state_events(
     event_id TEXT NOT NULL,
     room_id TEXT NOT NULL,
diff --git a/synapse/storage/schema/state.sql b/synapse/storage/schema/state.sql
index 44f7aafb27..2c48d6daca 100644
--- a/synapse/storage/schema/state.sql
+++ b/synapse/storage/schema/state.sql
@@ -29,7 +29,8 @@ CREATE TABLE IF NOT EXISTS state_groups_state(
 
 CREATE TABLE IF NOT EXISTS event_to_state_groups(
     event_id TEXT NOT NULL,
-    state_group INTEGER NOT NULL
+    state_group INTEGER NOT NULL,
+    CONSTRAINT event_to_state_groups_uniq UNIQUE (event_id)
 );
 
 CREATE INDEX IF NOT EXISTS state_groups_id ON state_groups(id);
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
index eea4f21065..3a705119fd 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/signatures.py
@@ -13,8 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from twisted.internet import defer
+
 from _base import SQLBaseStore
 
+from syutil.base64util import encode_base64
+
 
 class SignatureStore(SQLBaseStore):
     """Persistence for event signatures and hashes"""
@@ -67,6 +71,21 @@ class SignatureStore(SQLBaseStore):
             f
         )
 
+    @defer.inlineCallbacks
+    def add_event_hashes(self, event_ids):
+        hashes = yield self.get_event_reference_hashes(
+            event_ids
+        )
+        hashes = [
+            {
+                k: encode_base64(v) for k, v in h.items()
+                if k == "sha256"
+            }
+            for h in hashes
+        ]
+
+        defer.returnValue(zip(event_ids, hashes))
+
     def _get_event_reference_hashes_txn(self, txn, event_id):
         """Get all the hashes for a given PDU.
         Args:
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index e0f44b3e59..afe3e5edea 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -86,11 +86,16 @@ class StateStore(SQLBaseStore):
             self._store_state_groups_txn, event
         )
 
-    def _store_state_groups_txn(self, txn, event):
-        if event.state_events is None:
+    def _store_state_groups_txn(self, txn, event, context):
+        if context.current_state is None:
             return
 
-        state_group = event.state_group
+        state_events = context.current_state
+
+        if event.is_state():
+            state_events[(event.type, event.state_key)] = event
+
+        state_group = context.state_group
         if not state_group:
             state_group = self._simple_insert_txn(
                 txn,
@@ -102,7 +107,7 @@ class StateStore(SQLBaseStore):
                 or_ignore=True,
             )
 
-            for state in event.state_events.values():
+            for state in state_events.values():
                 self._simple_insert_txn(
                     txn,
                     table="state_groups_state",