diff options
Diffstat (limited to 'synapse/storage/events.py')
-rw-r--r-- | synapse/storage/events.py | 246 |
1 files changed, 246 insertions, 0 deletions
diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 9242b0a84e..f960ef8350 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -17,6 +17,10 @@ from _base import SQLBaseStore, _RollbackButIsFineException from twisted.internet import defer +from synapse.events import FrozenEvent +from synapse.events.utils import prune_event +from synapse.util import unwrap_deferred + from synapse.util.logutils import log_function from synapse.api.constants import EventTypes from synapse.crypto.event_signing import compute_event_reference_hash @@ -26,6 +30,7 @@ from syutil.jsonutil import encode_canonical_json from contextlib import contextmanager import logging +import simplejson as json logger = logging.getLogger(__name__) @@ -393,3 +398,244 @@ class EventsStore(SQLBaseStore): return self.runInteraction( "have_events", f, ) + + @defer.inlineCallbacks + def _get_events(self, event_ids, check_redacted=True, + get_prev_content=False, allow_rejected=False, txn=None): + if not event_ids: + defer.returnValue([]) + + event_map = self._get_events_from_cache( + event_ids, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + + missing_events_ids = [e for e in event_ids if e not in event_map] + + if not missing_events_ids: + defer.returnValue([ + event_map[e_id] for e_id in event_ids + if e_id in event_map and event_map[e_id] + ]) + + if not txn: + missing_events = yield self.runInteraction( + "_get_events", + self._fetch_events_txn, + missing_events_ids, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + else: + missing_events = yield self._fetch_events( + txn, + missing_events_ids, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + + event_map.update(missing_events) + + defer.returnValue([ + event_map[e_id] for e_id in event_ids + if e_id in event_map and event_map[e_id] + ]) + + + def _get_events_txn(self, txn, event_ids, check_redacted=True, + get_prev_content=False, allow_rejected=False): + return unwrap_deferred(self._get_events( + event_ids, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + txn=txn, + )) + + def _invalidate_get_event_cache(self, event_id): + for check_redacted in (False, True): + for get_prev_content in (False, True): + self._get_event_cache.invalidate(event_id, check_redacted, + get_prev_content) + + def _get_event_txn(self, txn, event_id, check_redacted=True, + get_prev_content=False, allow_rejected=False): + + events = self._get_events_txn( + txn, [event_id], + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + + return events[0] if events else None + + def _get_events_from_cache(self, events, check_redacted, get_prev_content, + allow_rejected): + event_map = {} + + for event_id in events: + try: + ret = self._get_event_cache.get( + event_id, check_redacted, get_prev_content + ) + + if allow_rejected or not ret.rejected_reason: + event_map[event_id] = ret + else: + event_map[event_id] = None + except KeyError: + pass + + return event_map + + def _fetch_events_txn(self, txn, events, check_redacted=True, + get_prev_content=False, allow_rejected=False): + return unwrap_deferred(self._fetch_events( + txn, events, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + )) + + @defer.inlineCallbacks + def _fetch_events(self, txn, events, check_redacted=True, + get_prev_content=False, allow_rejected=False): + if not events: + defer.returnValue({}) + + rows = [] + N = 200 + for i in range(1 + len(events) / N): + evs = events[i*N:(i + 1)*N] + if not evs: + break + + sql = ( + "SELECT e.internal_metadata, e.json, r.redacts, rej.event_id " + " FROM event_json as e" + " LEFT JOIN rejections as rej USING (event_id)" + " LEFT JOIN redactions as r ON e.event_id = r.redacts" + " WHERE e.event_id IN (%s)" + ) % (",".join(["?"]*len(evs)),) + + if txn: + txn.execute(sql, evs) + rows.extend(txn.fetchall()) + else: + res = yield self._execute("_fetch_events", None, sql, *evs) + rows.extend(res) + + res = yield defer.gatherResults( + [ + defer.maybeDeferred( + self._get_event_from_row, + txn, + row[0], row[1], row[2], + check_redacted=check_redacted, + get_prev_content=get_prev_content, + rejected_reason=row[3], + ) + for row in rows + ], + consumeErrors=True, + ) + + for e in res: + self._get_event_cache.prefill( + e.event_id, check_redacted, get_prev_content, e + ) + + defer.returnValue({ + e.event_id: e + for e in res if e + }) + + @defer.inlineCallbacks + def _get_event_from_row(self, txn, internal_metadata, js, redacted, + check_redacted=True, get_prev_content=False, + rejected_reason=None): + d = json.loads(js) + internal_metadata = json.loads(internal_metadata) + + def select(txn, *args, **kwargs): + if txn: + return self._simple_select_one_onecol_txn(txn, *args, **kwargs) + else: + return self._simple_select_one_onecol( + *args, + desc="_get_event_from_row", **kwargs + ) + + def get_event(txn, *args, **kwargs): + if txn: + return self._get_event_txn(txn, *args, **kwargs) + else: + return self.get_event(*args, **kwargs) + + if rejected_reason: + rejected_reason = yield select( + txn, + table="rejections", + keyvalues={"event_id": rejected_reason}, + retcol="reason", + ) + + ev = FrozenEvent( + d, + internal_metadata_dict=internal_metadata, + rejected_reason=rejected_reason, + ) + + if check_redacted and redacted: + ev = prune_event(ev) + + redaction_id = yield select( + txn, + table="redactions", + keyvalues={"redacts": ev.event_id}, + retcol="event_id", + ) + + ev.unsigned["redacted_by"] = redaction_id + # Get the redaction event. + + because = yield get_event( + txn, + redaction_id, + check_redacted=False + ) + + if because: + ev.unsigned["redacted_because"] = because + + if get_prev_content and "replaces_state" in ev.unsigned: + prev = yield get_event( + txn, + ev.unsigned["replaces_state"], + get_prev_content=False, + ) + if prev: + ev.unsigned["prev_content"] = prev.get_dict()["content"] + + defer.returnValue(ev) + + def _parse_events(self, rows): + return self.runInteraction( + "_parse_events", self._parse_events_txn, rows + ) + + def _parse_events_txn(self, txn, rows): + event_ids = [r["event_id"] for r in rows] + + return self._get_events_txn(txn, event_ids) + + def _has_been_redacted_txn(self, txn, event): + sql = "SELECT event_id FROM redactions WHERE redacts = ?" + txn.execute(sql, (event.event_id,)) + result = txn.fetchone() + return result[0] if result else None |