diff options
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r-- | synapse/storage/_base.py | 158 |
1 files changed, 56 insertions, 102 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 4881f03368..e0d97f440b 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -15,15 +15,14 @@ import logging from synapse.api.errors import StoreError -from synapse.api.events.utils import prune_event +from synapse.events import FrozenEvent +from synapse.events.utils import prune_event from synapse.util.logutils import log_function from synapse.util.logcontext import PreserveLoggingContext, LoggingContext -from syutil.base64util import encode_base64 from twisted.internet import defer import collections -import copy import json import sys import time @@ -84,7 +83,6 @@ class SQLBaseStore(object): def __init__(self, hs): self.hs = hs self._db_pool = hs.get_db_pool() - self.event_factory = hs.get_event_factory() self._clock = hs.get_clock() @defer.inlineCallbacks @@ -436,123 +434,79 @@ class SQLBaseStore(object): return self.runInteraction("_simple_max_id", func) - def _parse_event_from_row(self, row_dict): - d = copy.deepcopy({k: v for k, v in row_dict.items()}) - - d.pop("stream_ordering", None) - d.pop("topological_ordering", None) - d.pop("processed", None) - d["origin_server_ts"] = d.pop("ts", 0) - replaces_state = d.pop("prev_state", None) + def _get_events(self, event_ids): + return self.runInteraction( + "_get_events", self._get_events_txn, event_ids + ) - if replaces_state: - d["replaces_state"] = replaces_state + def _get_events_txn(self, txn, event_ids): + events = [] + for e_id in event_ids: + ev = self._get_event_txn(txn, e_id) - d.update(json.loads(row_dict["unrecognized_keys"])) - d["content"] = json.loads(d["content"]) - del d["unrecognized_keys"] + if ev: + events.append(ev) - if "age_ts" not in d: - # For compatibility - d["age_ts"] = d.get("origin_server_ts", 0) + return events - return self.event_factory.create_event( - etype=d["type"], - **d + def _get_event_txn(self, txn, event_id, check_redacted=True, + get_prev_content=True): + sql = ( + "SELECT internal_metadata, json, r.event_id FROM event_json as e " + "LEFT JOIN redactions as r ON e.event_id = r.redacts " + "WHERE e.event_id = ? " + "LIMIT 1 " ) - def _get_events_txn(self, txn, event_ids): - # FIXME (erikj): This should be batched? + txn.execute(sql, (event_id,)) - sql = "SELECT * FROM events WHERE event_id = ? ORDER BY rowid asc" + res = txn.fetchone() - event_rows = [] - for e_id in event_ids: - c = txn.execute(sql, (e_id,)) - event_rows.extend(self.cursor_to_dict(c)) + if not res: + return None - return self._parse_events_txn(txn, event_rows) + internal_metadata, js, redacted = res - def _parse_events(self, rows): - return self.runInteraction( - "_parse_events", self._parse_events_txn, rows - ) + d = json.loads(js) + internal_metadata = json.loads(internal_metadata) - def _parse_events_txn(self, txn, rows): - events = [self._parse_event_from_row(r) for r in rows] + ev = FrozenEvent(d, internal_metadata_dict=internal_metadata) - select_event_sql = ( - "SELECT * FROM events WHERE event_id = ? ORDER BY rowid asc" - ) + if check_redacted and redacted: + ev = prune_event(ev) + + ev.unsigned["redacted_by"] = redacted + # Get the redaction event. - for i, ev in enumerate(events): - signatures = self._get_event_signatures_txn( - txn, ev.event_id, + because = self._get_event_txn( + txn, + redacted, + check_redacted=False ) - ev.signatures = { - n: { - k: encode_base64(v) for k, v in s.items() - } - for n, s in signatures.items() - } + if because: + ev.unsigned["redacted_because"] = because - hashes = self._get_event_content_hashes_txn( - txn, ev.event_id, + if get_prev_content and "replaces_state" in ev.unsigned: + prev = self._get_event_txn( + txn, + ev.unsigned["replaces_state"], + get_prev_content=False, ) + if prev: + ev.unsigned["prev_content"] = prev.get_dict()["content"] - ev.hashes = { - k: encode_base64(v) for k, v in hashes.items() - } - - prevs = self._get_prev_events_and_state(txn, ev.event_id) - - ev.prev_events = [ - (e_id, h) - for e_id, h, is_state in prevs - if is_state == 0 - ] - - ev.auth_events = self._get_auth_events(txn, ev.event_id) - - if hasattr(ev, "state_key"): - ev.prev_state = [ - (e_id, h) - for e_id, h, is_state in prevs - if is_state == 1 - ] - - if hasattr(ev, "replaces_state"): - # Load previous state_content. - # FIXME (erikj): Handle multiple prev_states. - cursor = txn.execute( - select_event_sql, - (ev.replaces_state,) - ) - prevs = self.cursor_to_dict(cursor) - if prevs: - prev = self._parse_event_from_row(prevs[0]) - ev.prev_content = prev.content - - if not hasattr(ev, "redacted"): - logger.debug("Doesn't have redacted key: %s", ev) - ev.redacted = self._has_been_redacted_txn(txn, ev) - - if ev.redacted: - # Get the redaction event. - select_event_sql = "SELECT * FROM events WHERE event_id = ?" - txn.execute(select_event_sql, (ev.redacted,)) - - del_evs = self._parse_events_txn( - txn, self.cursor_to_dict(txn) - ) + return ev - if del_evs: - ev = prune_event(ev) - events[i] = ev - ev.redacted_because = del_evs[0] + def _parse_events(self, rows): + return self.runInteraction( + "_parse_events", self._parse_events_txn, rows + ) - return events + def _parse_events_txn(self, txn, rows): + event_ids = [r["event_id"] for r in rows] + + return self._get_events_txn(txn, event_ids) def _has_been_redacted_txn(self, txn, event): sql = "SELECT event_id FROM redactions WHERE redacts = ?" @@ -650,7 +604,7 @@ class JoinHelper(object): to dump the results into. Attributes: - taples (list): List of `Table` classes + tables (list): List of `Table` classes EntryType (type) """ |