summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/__init__.py74
-rw-r--r--synapse/storage/_base.py104
-rw-r--r--synapse/storage/schema/im.sql10
-rw-r--r--synapse/storage/signatures.py19
-rw-r--r--synapse/storage/state.py13
5 files changed, 107 insertions, 113 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 04ab39341d..7068424441 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -21,6 +21,7 @@ from synapse.api.events.room import (
 )
 
 from synapse.util.logutils import log_function
+from synapse.util.frozenutils import FrozenEncoder
 
 from .directory import DirectoryStore
 from .feedback import FeedbackStore
@@ -93,8 +94,8 @@ class DataStore(RoomMemberStore, RoomStore,
 
     @defer.inlineCallbacks
     @log_function
-    def persist_event(self, event, backfilled=False, is_new_state=True,
-                      current_state=None):
+    def persist_event(self, event, context, backfilled=False,
+                      is_new_state=True, current_state=None):
         stream_ordering = None
         if backfilled:
             if not self.min_token_deferred.called:
@@ -107,6 +108,7 @@ class DataStore(RoomMemberStore, RoomStore,
                 "persist_event",
                 self._persist_event_txn,
                 event=event,
+                context=context,
                 backfilled=backfilled,
                 stream_ordering=stream_ordering,
                 is_new_state=is_new_state,
@@ -117,29 +119,20 @@ class DataStore(RoomMemberStore, RoomStore,
 
     @defer.inlineCallbacks
     def get_event(self, event_id, allow_none=False):
-        events_dict = yield self._simple_select_one(
-            "events",
-            {"event_id": event_id},
-            [
-                "event_id",
-                "type",
-                "room_id",
-                "content",
-                "unrecognized_keys",
-                "depth",
-            ],
-            allow_none=allow_none,
-        )
+        events = yield self._get_events([event_id])
 
-        if not events_dict:
-            defer.returnValue(None)
+        if not events:
+            if allow_none:
+                defer.returnValue(None)
+            else:
+                raise RuntimeError("Could not find event %s" % (event_id,))
 
-        event = yield self._parse_events([events_dict])
-        defer.returnValue(event[0])
+        defer.returnValue(events[0])
 
     @log_function
-    def _persist_event_txn(self, txn, event, backfilled, stream_ordering=None,
-                           is_new_state=True, current_state=None):
+    def _persist_event_txn(self, txn, event, context, backfilled,
+                           stream_ordering=None, is_new_state=True,
+                           current_state=None):
         if event.type == RoomMemberEvent.TYPE:
             self._store_room_member_txn(txn, event)
         elif event.type == FeedbackEvent.TYPE:
@@ -152,15 +145,34 @@ class DataStore(RoomMemberStore, RoomStore,
             self._store_redaction(txn, event)
 
         outlier = False
-        if hasattr(event, "outlier"):
-            outlier = event.outlier
+        if hasattr(event.internal_metadata, "outlier"):
+            outlier = event.internal_metadata.outlier
+
+        event_dict = {
+            k: v
+            for k, v in event.get_dict().items()
+            if k not in [
+                "redacted",
+                "redacted_because",
+            ]
+        }
+
+        self._simple_insert_txn(
+            txn,
+            table="event_json",
+            values={
+                "event_id": event.event_id,
+                "json": json.dumps(event_dict, separators=(',', ':')),
+            },
+            or_replace=True,
+        )
 
         vals = {
             "topological_ordering": event.depth,
             "event_id": event.event_id,
             "type": event.type,
             "room_id": event.room_id,
-            "content": json.dumps(event.content),
+            "content": json.dumps(event.content, cls=FrozenEncoder),
             "processed": True,
             "outlier": outlier,
             "depth": event.depth,
@@ -171,7 +183,7 @@ class DataStore(RoomMemberStore, RoomStore,
 
         unrec = {
             k: v
-            for k, v in event.get_full_dict().items()
+            for k, v in event.get_dict().items()
             if k not in vals.keys() and k not in [
                 "redacted",
                 "redacted_because",
@@ -206,7 +218,7 @@ class DataStore(RoomMemberStore, RoomStore,
             room_id=event.room_id,
         )
 
-        self._store_state_groups_txn(txn, event)
+        self._store_state_groups_txn(txn, event, context)
 
         if current_state:
             txn.execute(
@@ -300,16 +312,6 @@ class DataStore(RoomMemberStore, RoomStore,
                 txn, event.event_id, hash_alg, hash_bytes,
             )
 
-        if hasattr(event, "signatures"):
-            logger.debug("sigs: %s", event.signatures)
-            for name, sigs in event.signatures.items():
-                for key_id, signature_base64 in sigs.items():
-                    signature_bytes = decode_base64(signature_base64)
-                    self._store_event_signature_txn(
-                        txn, event.event_id, name, key_id,
-                        signature_bytes,
-                    )
-
         for prev_event_id, prev_hashes in event.prev_events:
             for alg, hash_base64 in prev_hashes.items():
                 hash_bytes = decode_base64(hash_base64)
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index e72200e2f7..935a167f6f 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -15,10 +15,10 @@
 import logging
 
 from synapse.api.errors import StoreError
-from synapse.api.events.utils import prune_event
+from synapse.events import FrozenEvent
+from synapse.events.utils import prune_event
 from synapse.util.logutils import log_function
 from synapse.util.logcontext import PreserveLoggingContext, LoggingContext
-from syutil.base64util import encode_base64
 
 from twisted.internet import defer
 
@@ -461,84 +461,31 @@ class SQLBaseStore(object):
             **d
         )
 
-    def _get_events_txn(self, txn, event_ids):
-        # FIXME (erikj): This should be batched?
-
-        sql = "SELECT * FROM events WHERE event_id = ? ORDER BY rowid asc"
-
-        event_rows = []
-        for e_id in event_ids:
-            c = txn.execute(sql, (e_id,))
-            event_rows.extend(self.cursor_to_dict(c))
-
-        return self._parse_events_txn(txn, event_rows)
-
-    def _parse_events(self, rows):
+    def _get_events(self, event_ids):
         return self.runInteraction(
-            "_parse_events", self._parse_events_txn, rows
-        )
-
-    def _parse_events_txn(self, txn, rows):
-        events = [self._parse_event_from_row(r) for r in rows]
-
-        select_event_sql = (
-            "SELECT * FROM events WHERE event_id = ? ORDER BY rowid asc"
+            "_get_events", self._get_events_txn, event_ids
         )
 
-        for i, ev in enumerate(events):
-            signatures = self._get_event_signatures_txn(
-                txn, ev.event_id,
+    def _get_events_txn(self, txn, event_ids):
+        events = []
+        for e_id in event_ids:
+            js = self._simple_select_one_onecol_txn(
+                txn,
+                table="event_json",
+                keyvalues={"event_id": e_id},
+                retcol="json",
+                allow_none=True,
             )
 
-            ev.signatures = {
-                n: {
-                    k: encode_base64(v) for k, v in s.items()
-                }
-                for n, s in signatures.items()
-            }
-
-            hashes = self._get_event_content_hashes_txn(
-                txn, ev.event_id,
-            )
+            if not js:
+                # FIXME (erikj): What should we actually do here?
+                continue
 
-            ev.hashes = {
-                k: encode_base64(v) for k, v in hashes.items()
-            }
-
-            prevs = self._get_prev_events_and_state(txn, ev.event_id)
-
-            ev.prev_events = [
-                (e_id, h)
-                for e_id, h, is_state in prevs
-                if is_state == 0
-            ]
-
-            ev.auth_events = self._get_auth_events(txn, ev.event_id)
-
-            if hasattr(ev, "state_key"):
-                ev.prev_state = [
-                    (e_id, h)
-                    for e_id, h, is_state in prevs
-                    if is_state == 1
-                ]
-
-                if hasattr(ev, "replaces_state"):
-                    # Load previous state_content.
-                    # FIXME (erikj): Handle multiple prev_states.
-                    cursor = txn.execute(
-                        select_event_sql,
-                        (ev.replaces_state,)
-                    )
-                    prevs = self.cursor_to_dict(cursor)
-                    if prevs:
-                        prev = self._parse_event_from_row(prevs[0])
-                        ev.prev_content = prev.content
+            d = json.loads(js)
 
-            if not hasattr(ev, "redacted"):
-                logger.debug("Doesn't have redacted key: %s", ev)
-                ev.redacted = self._has_been_redacted_txn(txn, ev)
+            ev = FrozenEvent(d)
 
-            if ev.redacted:
+            if hasattr(ev, "redacted") and ev.redacted:
                 # Get the redaction event.
                 select_event_sql = "SELECT * FROM events WHERE event_id = ?"
                 txn.execute(select_event_sql, (ev.redacted,))
@@ -549,11 +496,22 @@ class SQLBaseStore(object):
 
                 if del_evs:
                     ev = prune_event(ev)
-                    events[i] = ev
                     ev.redacted_because = del_evs[0]
 
+            events.append(ev)
+
         return events
 
+    def _parse_events(self, rows):
+        return self.runInteraction(
+            "_parse_events", self._parse_events_txn, rows
+        )
+
+    def _parse_events_txn(self, txn, rows):
+        event_ids = [r["event_id"] for r in rows]
+
+        return self._get_events_txn(txn, event_ids)
+
     def _has_been_redacted_txn(self, txn, event):
         sql = "SELECT event_id FROM redactions WHERE redacts = ?"
         txn.execute(sql, (event.event_id,))
diff --git a/synapse/storage/schema/im.sql b/synapse/storage/schema/im.sql
index 8ba732a23b..cb0c494ddf 100644
--- a/synapse/storage/schema/im.sql
+++ b/synapse/storage/schema/im.sql
@@ -32,6 +32,16 @@ CREATE INDEX IF NOT EXISTS events_stream_ordering ON events (stream_ordering);
 CREATE INDEX IF NOT EXISTS events_topological_ordering ON events (topological_ordering);
 CREATE INDEX IF NOT EXISTS events_room_id ON events (room_id);
 
+
+CREATE TABLE IF NOT EXISTS event_json(
+    event_id TEXT NOT NULL,
+    json BLOB NOT NULL,
+    CONSTRAINT ev_j_uniq UNIQUE (event_id)
+);
+
+CREATE INDEX IF NOT EXISTS event_json_id ON event_json(event_id);
+
+
 CREATE TABLE IF NOT EXISTS state_events(
     event_id TEXT NOT NULL,
     room_id TEXT NOT NULL,
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
index eea4f21065..3a705119fd 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/signatures.py
@@ -13,8 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from twisted.internet import defer
+
 from _base import SQLBaseStore
 
+from syutil.base64util import encode_base64
+
 
 class SignatureStore(SQLBaseStore):
     """Persistence for event signatures and hashes"""
@@ -67,6 +71,21 @@ class SignatureStore(SQLBaseStore):
             f
         )
 
+    @defer.inlineCallbacks
+    def add_event_hashes(self, event_ids):
+        hashes = yield self.get_event_reference_hashes(
+            event_ids
+        )
+        hashes = [
+            {
+                k: encode_base64(v) for k, v in h.items()
+                if k == "sha256"
+            }
+            for h in hashes
+        ]
+
+        defer.returnValue(zip(event_ids, hashes))
+
     def _get_event_reference_hashes_txn(self, txn, event_id):
         """Get all the hashes for a given PDU.
         Args:
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index e0f44b3e59..afe3e5edea 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -86,11 +86,16 @@ class StateStore(SQLBaseStore):
             self._store_state_groups_txn, event
         )
 
-    def _store_state_groups_txn(self, txn, event):
-        if event.state_events is None:
+    def _store_state_groups_txn(self, txn, event, context):
+        if context.current_state is None:
             return
 
-        state_group = event.state_group
+        state_events = context.current_state
+
+        if event.is_state():
+            state_events[(event.type, event.state_key)] = event
+
+        state_group = context.state_group
         if not state_group:
             state_group = self._simple_insert_txn(
                 txn,
@@ -102,7 +107,7 @@ class StateStore(SQLBaseStore):
                 or_ignore=True,
             )
 
-            for state in event.state_events.values():
+            for state in state_events.values():
                 self._simple_insert_txn(
                     txn,
                     table="state_groups_state",