summary refs log tree commit diff
path: root/synapse/storage/events.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/events.py')
-rw-r--r--synapse/storage/events.py132
1 files changed, 132 insertions, 0 deletions
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 971f3211ac..4fee155234 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -17,6 +17,9 @@ from _base import SQLBaseStore, _RollbackButIsFineException
 
 from twisted.internet import defer
 
+from synapse.events import FrozenEvent
+from synapse.events.utils import prune_event
+
 from synapse.util.logutils import log_function
 from synapse.api.constants import EventTypes
 from synapse.crypto.event_signing import compute_event_reference_hash
@@ -26,6 +29,7 @@ from syutil.jsonutil import encode_canonical_json
 from contextlib import contextmanager
 
 import logging
+import simplejson as json
 
 logger = logging.getLogger(__name__)
 
@@ -396,3 +400,131 @@ class EventsStore(SQLBaseStore):
         return self.runInteraction(
             "have_events", f,
         )
+
+    def _get_events(self, event_ids, check_redacted=True,
+                    get_prev_content=False):
+        return self.runInteraction(
+            "_get_events", self._get_events_txn, event_ids,
+            check_redacted=check_redacted, get_prev_content=get_prev_content,
+        )
+
+    def _get_events_txn(self, txn, event_ids, check_redacted=True,
+                        get_prev_content=False):
+        if not event_ids:
+            return []
+
+        events = [
+            self._get_event_txn(
+                txn, event_id,
+                check_redacted=check_redacted,
+                get_prev_content=get_prev_content
+            )
+            for event_id in event_ids
+        ]
+
+        return [e for e in events if e]
+
+    def _invalidate_get_event_cache(self, event_id):
+        for check_redacted in (False, True):
+            for get_prev_content in (False, True):
+                self._get_event_cache.invalidate(event_id, check_redacted,
+                                                 get_prev_content)
+
+    def _get_event_txn(self, txn, event_id, check_redacted=True,
+                       get_prev_content=False, allow_rejected=False):
+
+        try:
+            ret = self._get_event_cache.get(event_id, check_redacted, get_prev_content)
+
+            if allow_rejected or not ret.rejected_reason:
+                return ret
+            else:
+                return None
+        except KeyError:
+            pass
+
+        sql = (
+            "SELECT e.internal_metadata, e.json, r.event_id, rej.reason "
+            "FROM event_json as e "
+            "LEFT JOIN redactions as r ON e.event_id = r.redacts "
+            "LEFT JOIN rejections as rej on rej.event_id = e.event_id  "
+            "WHERE e.event_id = ? "
+            "LIMIT 1 "
+        )
+
+        txn.execute(sql, (event_id,))
+
+        res = txn.fetchone()
+
+        if not res:
+            return None
+
+        internal_metadata, js, redacted, rejected_reason = res
+
+        result = self._get_event_from_row_txn(
+            txn, internal_metadata, js, redacted,
+            check_redacted=check_redacted,
+            get_prev_content=get_prev_content,
+            rejected_reason=rejected_reason,
+        )
+        self._get_event_cache.prefill(event_id, check_redacted, get_prev_content, result)
+
+        if allow_rejected or not rejected_reason:
+            return result
+        else:
+            return None
+
+    def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted,
+                                check_redacted=True, get_prev_content=False,
+                                rejected_reason=None):
+
+        d = json.loads(js)
+        internal_metadata = json.loads(internal_metadata)
+
+        ev = FrozenEvent(
+            d,
+            internal_metadata_dict=internal_metadata,
+            rejected_reason=rejected_reason,
+        )
+
+        if check_redacted and redacted:
+            ev = prune_event(ev)
+
+            ev.unsigned["redacted_by"] = redacted
+            # Get the redaction event.
+
+            because = self._get_event_txn(
+                txn,
+                redacted,
+                check_redacted=False
+            )
+
+            if because:
+                ev.unsigned["redacted_because"] = because
+
+        if get_prev_content and "replaces_state" in ev.unsigned:
+            prev = self._get_event_txn(
+                txn,
+                ev.unsigned["replaces_state"],
+                get_prev_content=False,
+            )
+            if prev:
+                ev.unsigned["prev_content"] = prev.get_dict()["content"]
+
+        return ev
+
+    def _parse_events(self, rows):
+        return self.runInteraction(
+            "_parse_events", self._parse_events_txn, rows
+        )
+
+    def _parse_events_txn(self, txn, rows):
+        event_ids = [r["event_id"] for r in rows]
+
+        return self._get_events_txn(txn, event_ids)
+
+    def _has_been_redacted_txn(self, txn, event):
+        sql = "SELECT event_id FROM redactions WHERE redacts = ?"
+        txn.execute(sql, (event.event_id,))
+        result = txn.fetchone()
+        return result[0] if result else None