diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 81052409b7..ec80169c5b 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -15,8 +15,6 @@
import logging
from synapse.api.errors import StoreError
-from synapse.events import FrozenEvent
-from synapse.events.utils import prune_event
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
from synapse.util.lrucache import LruCache
@@ -28,7 +26,6 @@ from twisted.internet import defer
from collections import namedtuple, OrderedDict
import functools
-import simplejson as json
import sys
import time
import threading
@@ -867,158 +864,6 @@ class SQLBaseStore(object):
return self.runInteraction("_simple_max_id", func)
- def _get_events(self, event_ids, check_redacted=True,
- get_prev_content=False):
- return self.runInteraction(
- "_get_events", self._get_events_txn, event_ids,
- check_redacted=check_redacted, get_prev_content=get_prev_content,
- )
-
- def _get_events_txn(self, txn, event_ids, check_redacted=True,
- get_prev_content=False):
- if not event_ids:
- return []
-
- events = [
- self._get_event_txn(
- txn, event_id,
- check_redacted=check_redacted,
- get_prev_content=get_prev_content
- )
- for event_id in event_ids
- ]
-
- return [e for e in events if e]
-
- def _invalidate_get_event_cache(self, event_id):
- for check_redacted in (False, True):
- for get_prev_content in (False, True):
- self._get_event_cache.invalidate(event_id, check_redacted,
- get_prev_content)
-
- def _get_event_txn(self, txn, event_id, check_redacted=True,
- get_prev_content=False, allow_rejected=False):
-
- start_time = time.time() * 1000
-
- def update_counter(desc, last_time):
- curr_time = self._get_event_counters.update(desc, last_time)
- sql_getevents_timer.inc_by(curr_time - last_time, desc)
- return curr_time
-
- try:
- ret = self._get_event_cache.get(event_id, check_redacted, get_prev_content)
-
- if allow_rejected or not ret.rejected_reason:
- return ret
- else:
- return None
- except KeyError:
- pass
- finally:
- start_time = update_counter("event_cache", start_time)
-
- sql = (
- "SELECT e.internal_metadata, e.json, r.event_id, rej.reason "
- "FROM event_json as e "
- "LEFT JOIN redactions as r ON e.event_id = r.redacts "
- "LEFT JOIN rejections as rej on rej.event_id = e.event_id "
- "WHERE e.event_id = ? "
- "LIMIT 1 "
- )
-
- txn.execute(sql, (event_id,))
-
- res = txn.fetchone()
-
- if not res:
- return None
-
- internal_metadata, js, redacted, rejected_reason = res
-
- start_time = update_counter("select_event", start_time)
-
- result = self._get_event_from_row_txn(
- txn, internal_metadata, js, redacted,
- check_redacted=check_redacted,
- get_prev_content=get_prev_content,
- rejected_reason=rejected_reason,
- )
- self._get_event_cache.prefill(event_id, check_redacted, get_prev_content, result)
-
- if allow_rejected or not rejected_reason:
- return result
- else:
- return None
-
- def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted,
- check_redacted=True, get_prev_content=False,
- rejected_reason=None):
-
- start_time = time.time() * 1000
-
- def update_counter(desc, last_time):
- curr_time = self._get_event_counters.update(desc, last_time)
- sql_getevents_timer.inc_by(curr_time - last_time, desc)
- return curr_time
-
- d = json.loads(js)
- start_time = update_counter("decode_json", start_time)
-
- internal_metadata = json.loads(internal_metadata)
- start_time = update_counter("decode_internal", start_time)
-
- ev = FrozenEvent(
- d,
- internal_metadata_dict=internal_metadata,
- rejected_reason=rejected_reason,
- )
- start_time = update_counter("build_frozen_event", start_time)
-
- if check_redacted and redacted:
- ev = prune_event(ev)
-
- ev.unsigned["redacted_by"] = redacted
- # Get the redaction event.
-
- because = self._get_event_txn(
- txn,
- redacted,
- check_redacted=False
- )
-
- if because:
- ev.unsigned["redacted_because"] = because
- start_time = update_counter("redact_event", start_time)
-
- if get_prev_content and "replaces_state" in ev.unsigned:
- prev = self._get_event_txn(
- txn,
- ev.unsigned["replaces_state"],
- get_prev_content=False,
- )
- if prev:
- ev.unsigned["prev_content"] = prev.get_dict()["content"]
- start_time = update_counter("get_prev_content", start_time)
-
- return ev
-
- def _parse_events(self, rows):
- return self.runInteraction(
- "_parse_events", self._parse_events_txn, rows
- )
-
- def _parse_events_txn(self, txn, rows):
- event_ids = [r["event_id"] for r in rows]
-
- return self._get_events_txn(txn, event_ids)
-
- def _has_been_redacted_txn(self, txn, event):
- sql = "SELECT event_id FROM redactions WHERE redacts = ?"
- txn.execute(sql, (event.event_id,))
- result = txn.fetchone()
- return result[0] if result else None
-
def get_next_stream_id(self):
with self._next_stream_id_lock:
i = self._next_stream_id
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 9242b0a84e..afdf0f7193 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -17,6 +17,9 @@ from _base import SQLBaseStore, _RollbackButIsFineException
from twisted.internet import defer
+from synapse.events import FrozenEvent
+from synapse.events.utils import prune_event
+
from synapse.util.logutils import log_function
from synapse.api.constants import EventTypes
from synapse.crypto.event_signing import compute_event_reference_hash
@@ -26,6 +29,7 @@ from syutil.jsonutil import encode_canonical_json
from contextlib import contextmanager
import logging
+import simplejson as json
logger = logging.getLogger(__name__)
@@ -393,3 +397,131 @@ class EventsStore(SQLBaseStore):
return self.runInteraction(
"have_events", f,
)
+
+ def _get_events(self, event_ids, check_redacted=True,
+ get_prev_content=False):
+ return self.runInteraction(
+ "_get_events", self._get_events_txn, event_ids,
+ check_redacted=check_redacted, get_prev_content=get_prev_content,
+ )
+
+ def _get_events_txn(self, txn, event_ids, check_redacted=True,
+ get_prev_content=False):
+ if not event_ids:
+ return []
+
+ events = [
+ self._get_event_txn(
+ txn, event_id,
+ check_redacted=check_redacted,
+ get_prev_content=get_prev_content
+ )
+ for event_id in event_ids
+ ]
+
+ return [e for e in events if e]
+
+ def _invalidate_get_event_cache(self, event_id):
+ for check_redacted in (False, True):
+ for get_prev_content in (False, True):
+ self._get_event_cache.invalidate(event_id, check_redacted,
+ get_prev_content)
+
+ def _get_event_txn(self, txn, event_id, check_redacted=True,
+ get_prev_content=False, allow_rejected=False):
+
+ try:
+ ret = self._get_event_cache.get(event_id, check_redacted, get_prev_content)
+
+ if allow_rejected or not ret.rejected_reason:
+ return ret
+ else:
+ return None
+ except KeyError:
+ pass
+
+ sql = (
+ "SELECT e.internal_metadata, e.json, r.event_id, rej.reason "
+ "FROM event_json as e "
+ "LEFT JOIN redactions as r ON e.event_id = r.redacts "
+ "LEFT JOIN rejections as rej on rej.event_id = e.event_id "
+ "WHERE e.event_id = ? "
+ "LIMIT 1 "
+ )
+
+ txn.execute(sql, (event_id,))
+
+ res = txn.fetchone()
+
+ if not res:
+ return None
+
+ internal_metadata, js, redacted, rejected_reason = res
+
+ result = self._get_event_from_row_txn(
+ txn, internal_metadata, js, redacted,
+ check_redacted=check_redacted,
+ get_prev_content=get_prev_content,
+ rejected_reason=rejected_reason,
+ )
+ self._get_event_cache.prefill(event_id, check_redacted, get_prev_content, result)
+
+ if allow_rejected or not rejected_reason:
+ return result
+ else:
+ return None
+
+ def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted,
+ check_redacted=True, get_prev_content=False,
+ rejected_reason=None):
+
+ d = json.loads(js)
+ internal_metadata = json.loads(internal_metadata)
+
+ ev = FrozenEvent(
+ d,
+ internal_metadata_dict=internal_metadata,
+ rejected_reason=rejected_reason,
+ )
+
+ if check_redacted and redacted:
+ ev = prune_event(ev)
+
+ ev.unsigned["redacted_by"] = redacted
+ # Get the redaction event.
+
+ because = self._get_event_txn(
+ txn,
+ redacted,
+ check_redacted=False
+ )
+
+ if because:
+ ev.unsigned["redacted_because"] = because
+
+ if get_prev_content and "replaces_state" in ev.unsigned:
+ prev = self._get_event_txn(
+ txn,
+ ev.unsigned["replaces_state"],
+ get_prev_content=False,
+ )
+ if prev:
+ ev.unsigned["prev_content"] = prev.get_dict()["content"]
+
+ return ev
+
+ def _parse_events(self, rows):
+ return self.runInteraction(
+ "_parse_events", self._parse_events_txn, rows
+ )
+
+ def _parse_events_txn(self, txn, rows):
+ event_ids = [r["event_id"] for r in rows]
+
+ return self._get_events_txn(txn, event_ids)
+
+ def _has_been_redacted_txn(self, txn, event):
+ sql = "SELECT event_id FROM redactions WHERE redacts = ?"
+ txn.execute(sql, (event.event_id,))
+ result = txn.fetchone()
+ return result[0] if result else None
|