summary refs log tree commit diff
path: root/synapse/storage/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r--synapse/storage/_base.py158
1 files changed, 55 insertions, 103 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index eb8cc4a9f3..efb2664680 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -15,15 +15,14 @@
 import logging
 
 from synapse.api.errors import StoreError
-from synapse.api.events.utils import prune_event
+from synapse.events import FrozenEvent
+from synapse.events.utils import prune_event
 from synapse.util.logutils import log_function
 from synapse.util.logcontext import PreserveLoggingContext, LoggingContext
-from syutil.base64util import encode_base64
 
 from twisted.internet import defer
 
 import collections
-import copy
 import json
 import sys
 import time
@@ -84,7 +83,6 @@ class SQLBaseStore(object):
     def __init__(self, hs):
         self.hs = hs
         self._db_pool = hs.get_db_pool()
-        self.event_factory = hs.get_event_factory()
         self._clock = hs.get_clock()
 
     @defer.inlineCallbacks
@@ -481,123 +479,77 @@ class SQLBaseStore(object):
 
         return self.runInteraction("_simple_max_id", func)
 
-    def _parse_event_from_row(self, row_dict):
-        d = copy.deepcopy({k: v for k, v in row_dict.items()})
-
-        d.pop("stream_ordering", None)
-        d.pop("topological_ordering", None)
-        d.pop("processed", None)
-        d["origin_server_ts"] = d.pop("ts", 0)
-        replaces_state = d.pop("prev_state", None)
+    def _get_events(self, event_ids):
+        return self.runInteraction(
+            "_get_events", self._get_events_txn, event_ids
+        )
 
-        if replaces_state:
-            d["replaces_state"] = replaces_state
+    def _get_events_txn(self, txn, event_ids):
+        events = []
+        for e_id in event_ids:
+            ev = self._get_event_txn(txn, e_id)
 
-        d.update(json.loads(row_dict["unrecognized_keys"]))
-        d["content"] = json.loads(d["content"])
-        del d["unrecognized_keys"]
+            if ev:
+                events.append(ev)
 
-        if "age_ts" not in d:
-            # For compatibility
-            d["age_ts"] = d.get("origin_server_ts", 0)
+        return events
 
-        return self.event_factory.create_event(
-            etype=d["type"],
-            **d
+    def _get_event_txn(self, txn, event_id, check_redacted=True,
+                       get_prev_content=True):
+        sql = (
+            "SELECT internal_metadata, json, r.event_id FROM event_json as e "
+            "LEFT JOIN redactions as r ON e.event_id = r.redacts "
+            "WHERE e.event_id = ? "
+            "LIMIT 1 "
         )
 
-    def _get_events_txn(self, txn, event_ids):
-        # FIXME (erikj): This should be batched?
+        txn.execute(sql, (event_id,))
 
-        sql = "SELECT * FROM events WHERE event_id = ? ORDER BY rowid asc"
+        res = txn.fetchone()
 
-        event_rows = []
-        for e_id in event_ids:
-            c = txn.execute(sql, (e_id,))
-            event_rows.extend(self.cursor_to_dict(c))
+        if not res:
+            return None
 
-        return self._parse_events_txn(txn, event_rows)
+        internal_metadata, js, redacted = res
 
-    def _parse_events(self, rows):
-        return self.runInteraction(
-            "_parse_events", self._parse_events_txn, rows
-        )
+        d = json.loads(js)
+        internal_metadata = json.loads(internal_metadata)
 
-    def _parse_events_txn(self, txn, rows):
-        events = [self._parse_event_from_row(r) for r in rows]
+        ev = FrozenEvent(d, internal_metadata_dict=internal_metadata)
 
-        select_event_sql = (
-            "SELECT * FROM events WHERE event_id = ? ORDER BY rowid asc"
-        )
+        if check_redacted and redacted:
+            ev = prune_event(ev)
 
-        for i, ev in enumerate(events):
-            signatures = self._get_event_signatures_txn(
-                txn, ev.event_id,
+            ev.unsigned["redacted_by"] = redacted
+            # Get the redaction event.
+
+            because = self._get_event_txn(
+                txn,
+                redacted,
+                check_redacted=False
             )
 
-            ev.signatures = {
-                n: {
-                    k: encode_base64(v) for k, v in s.items()
-                }
-                for n, s in signatures.items()
-            }
+            if because:
+                ev.unsigned["redacted_because"] = because
 
-            hashes = self._get_event_content_hashes_txn(
-                txn, ev.event_id,
-            )
+        if get_prev_content and "replaces_state" in ev.unsigned:
+            ev.unsigned["prev_content"] = self._get_event_txn(
+                txn,
+                ev.unsigned["replaces_state"],
+                get_prev_content=False,
+            ).get_dict()["content"]
 
-            ev.hashes = {
-                k: encode_base64(v) for k, v in hashes.items()
-            }
-
-            prevs = self._get_prev_events_and_state(txn, ev.event_id)
-
-            ev.prev_events = [
-                (e_id, h)
-                for e_id, h, is_state in prevs
-                if is_state == 0
-            ]
-
-            ev.auth_events = self._get_auth_events(txn, ev.event_id)
-
-            if hasattr(ev, "state_key"):
-                ev.prev_state = [
-                    (e_id, h)
-                    for e_id, h, is_state in prevs
-                    if is_state == 1
-                ]
-
-                if hasattr(ev, "replaces_state"):
-                    # Load previous state_content.
-                    # FIXME (erikj): Handle multiple prev_states.
-                    cursor = txn.execute(
-                        select_event_sql,
-                        (ev.replaces_state,)
-                    )
-                    prevs = self.cursor_to_dict(cursor)
-                    if prevs:
-                        prev = self._parse_event_from_row(prevs[0])
-                        ev.prev_content = prev.content
-
-            if not hasattr(ev, "redacted"):
-                logger.debug("Doesn't have redacted key: %s", ev)
-                ev.redacted = self._has_been_redacted_txn(txn, ev)
-
-            if ev.redacted:
-                # Get the redaction event.
-                select_event_sql = "SELECT * FROM events WHERE event_id = ?"
-                txn.execute(select_event_sql, (ev.redacted,))
-
-                del_evs = self._parse_events_txn(
-                    txn, self.cursor_to_dict(txn)
-                )
+        return ev
 
-                if del_evs:
-                    ev = prune_event(ev)
-                    events[i] = ev
-                    ev.redacted_because = del_evs[0]
+    def _parse_events(self, rows):
+        return self.runInteraction(
+            "_parse_events", self._parse_events_txn, rows
+        )
 
-        return events
+    def _parse_events_txn(self, txn, rows):
+        event_ids = [r["event_id"] for r in rows]
+
+        return self._get_events_txn(txn, event_ids)
 
     def _has_been_redacted_txn(self, txn, event):
         sql = "SELECT event_id FROM redactions WHERE redacts = ?"
@@ -695,7 +647,7 @@ class JoinHelper(object):
     to dump the results into.
 
     Attributes:
-        taples (list): List of `Table` classes
+        tables (list): List of `Table` classes
         EntryType (type)
     """