diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 4881f03368..e0d97f440b 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -15,15 +15,14 @@
import logging
from synapse.api.errors import StoreError
-from synapse.api.events.utils import prune_event
+from synapse.events import FrozenEvent
+from synapse.events.utils import prune_event
from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext, LoggingContext
-from syutil.base64util import encode_base64
from twisted.internet import defer
import collections
-import copy
import json
import sys
import time
@@ -84,7 +83,6 @@ class SQLBaseStore(object):
def __init__(self, hs):
self.hs = hs
self._db_pool = hs.get_db_pool()
- self.event_factory = hs.get_event_factory()
self._clock = hs.get_clock()
@defer.inlineCallbacks
@@ -436,123 +434,79 @@ class SQLBaseStore(object):
return self.runInteraction("_simple_max_id", func)
- def _parse_event_from_row(self, row_dict):
- d = copy.deepcopy({k: v for k, v in row_dict.items()})
-
- d.pop("stream_ordering", None)
- d.pop("topological_ordering", None)
- d.pop("processed", None)
- d["origin_server_ts"] = d.pop("ts", 0)
- replaces_state = d.pop("prev_state", None)
+ def _get_events(self, event_ids):
+ return self.runInteraction(
+ "_get_events", self._get_events_txn, event_ids
+ )
- if replaces_state:
- d["replaces_state"] = replaces_state
+ def _get_events_txn(self, txn, event_ids):
+ events = []
+ for e_id in event_ids:
+ ev = self._get_event_txn(txn, e_id)
- d.update(json.loads(row_dict["unrecognized_keys"]))
- d["content"] = json.loads(d["content"])
- del d["unrecognized_keys"]
+ if ev:
+ events.append(ev)
- if "age_ts" not in d:
- # For compatibility
- d["age_ts"] = d.get("origin_server_ts", 0)
+ return events
- return self.event_factory.create_event(
- etype=d["type"],
- **d
+ def _get_event_txn(self, txn, event_id, check_redacted=True,
+ get_prev_content=True):
+ sql = (
+ "SELECT internal_metadata, json, r.event_id FROM event_json as e "
+ "LEFT JOIN redactions as r ON e.event_id = r.redacts "
+ "WHERE e.event_id = ? "
+ "LIMIT 1 "
)
- def _get_events_txn(self, txn, event_ids):
- # FIXME (erikj): This should be batched?
+ txn.execute(sql, (event_id,))
- sql = "SELECT * FROM events WHERE event_id = ? ORDER BY rowid asc"
+ res = txn.fetchone()
- event_rows = []
- for e_id in event_ids:
- c = txn.execute(sql, (e_id,))
- event_rows.extend(self.cursor_to_dict(c))
+ if not res:
+ return None
- return self._parse_events_txn(txn, event_rows)
+ internal_metadata, js, redacted = res
- def _parse_events(self, rows):
- return self.runInteraction(
- "_parse_events", self._parse_events_txn, rows
- )
+ d = json.loads(js)
+ internal_metadata = json.loads(internal_metadata)
- def _parse_events_txn(self, txn, rows):
- events = [self._parse_event_from_row(r) for r in rows]
+ ev = FrozenEvent(d, internal_metadata_dict=internal_metadata)
- select_event_sql = (
- "SELECT * FROM events WHERE event_id = ? ORDER BY rowid asc"
- )
+ if check_redacted and redacted:
+ ev = prune_event(ev)
+
+ ev.unsigned["redacted_by"] = redacted
+ # Get the redaction event.
- for i, ev in enumerate(events):
- signatures = self._get_event_signatures_txn(
- txn, ev.event_id,
+ because = self._get_event_txn(
+ txn,
+ redacted,
+ check_redacted=False
)
- ev.signatures = {
- n: {
- k: encode_base64(v) for k, v in s.items()
- }
- for n, s in signatures.items()
- }
+ if because:
+ ev.unsigned["redacted_because"] = because
- hashes = self._get_event_content_hashes_txn(
- txn, ev.event_id,
+ if get_prev_content and "replaces_state" in ev.unsigned:
+ prev = self._get_event_txn(
+ txn,
+ ev.unsigned["replaces_state"],
+ get_prev_content=False,
)
+ if prev:
+ ev.unsigned["prev_content"] = prev.get_dict()["content"]
- ev.hashes = {
- k: encode_base64(v) for k, v in hashes.items()
- }
-
- prevs = self._get_prev_events_and_state(txn, ev.event_id)
-
- ev.prev_events = [
- (e_id, h)
- for e_id, h, is_state in prevs
- if is_state == 0
- ]
-
- ev.auth_events = self._get_auth_events(txn, ev.event_id)
-
- if hasattr(ev, "state_key"):
- ev.prev_state = [
- (e_id, h)
- for e_id, h, is_state in prevs
- if is_state == 1
- ]
-
- if hasattr(ev, "replaces_state"):
- # Load previous state_content.
- # FIXME (erikj): Handle multiple prev_states.
- cursor = txn.execute(
- select_event_sql,
- (ev.replaces_state,)
- )
- prevs = self.cursor_to_dict(cursor)
- if prevs:
- prev = self._parse_event_from_row(prevs[0])
- ev.prev_content = prev.content
-
- if not hasattr(ev, "redacted"):
- logger.debug("Doesn't have redacted key: %s", ev)
- ev.redacted = self._has_been_redacted_txn(txn, ev)
-
- if ev.redacted:
- # Get the redaction event.
- select_event_sql = "SELECT * FROM events WHERE event_id = ?"
- txn.execute(select_event_sql, (ev.redacted,))
-
- del_evs = self._parse_events_txn(
- txn, self.cursor_to_dict(txn)
- )
+ return ev
- if del_evs:
- ev = prune_event(ev)
- events[i] = ev
- ev.redacted_because = del_evs[0]
+ def _parse_events(self, rows):
+ return self.runInteraction(
+ "_parse_events", self._parse_events_txn, rows
+ )
- return events
+ def _parse_events_txn(self, txn, rows):
+ event_ids = [r["event_id"] for r in rows]
+
+ return self._get_events_txn(txn, event_ids)
def _has_been_redacted_txn(self, txn, event):
sql = "SELECT event_id FROM redactions WHERE redacts = ?"
@@ -650,7 +604,7 @@ class JoinHelper(object):
to dump the results into.
Attributes:
- taples (list): List of `Table` classes
+ tables (list): List of `Table` classes
EntryType (type)
"""
|