diff options
Diffstat (limited to 'synapse/storage/events.py')
-rw-r--r-- | synapse/storage/events.py | 455 |
1 files changed, 442 insertions, 13 deletions
diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 971f3211ac..d2a010bd88 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -15,8 +15,12 @@ from _base import SQLBaseStore, _RollbackButIsFineException -from twisted.internet import defer +from twisted.internet import defer, reactor +from synapse.events import FrozenEvent +from synapse.events.utils import prune_event + +from synapse.util.logcontext import preserve_context_over_deferred from synapse.util.logutils import log_function from synapse.api.constants import EventTypes from synapse.crypto.event_signing import compute_event_reference_hash @@ -26,10 +30,21 @@ from syutil.jsonutil import encode_canonical_json from contextlib import contextmanager import logging +import simplejson as json logger = logging.getLogger(__name__) +# These values are used in the `enqueus_event` and `_do_fetch` methods to +# control how we batch/bulk fetch events from the database. +# The values are plucked out of thing air to make initial sync run faster +# on jki.re +# TODO: Make these configurable. +EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events +EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events +EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events + + class EventsStore(SQLBaseStore): @defer.inlineCallbacks @log_function @@ -87,18 +102,17 @@ class EventsStore(SQLBaseStore): Returns: Deferred : A FrozenEvent. """ - event = yield self.runInteraction( - "get_event", self._get_event_txn, - event_id, + events = yield self._get_events( + [event_id], check_redacted=check_redacted, get_prev_content=get_prev_content, allow_rejected=allow_rejected, ) - if not event and not allow_none: + if not events and not allow_none: raise RuntimeError("Could not find event %s" % (event_id,)) - defer.returnValue(event) + defer.returnValue(events[0] if events else None) @log_function def _persist_event_txn(self, txn, event, context, backfilled, @@ -111,6 +125,12 @@ class EventsStore(SQLBaseStore): # We purposefully do this first since if we include a `current_state` # key, we *want* to update the `current_state_events` table if current_state: + txn.call_after(self.get_current_state_for_key.invalidate_all) + txn.call_after(self.get_rooms_for_user.invalidate_all) + txn.call_after(self.get_users_in_room.invalidate, event.room_id) + txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id) + txn.call_after(self.get_room_name_and_aliases, event.room_id) + self._simple_delete_txn( txn, table="current_state_events", @@ -118,13 +138,6 @@ class EventsStore(SQLBaseStore): ) for s in current_state: - if s.type == EventTypes.Member: - txn.call_after( - self.get_rooms_for_user.invalidate, s.state_key - ) - txn.call_after( - self.get_joined_hosts_for_room.invalidate, s.room_id - ) self._simple_insert_txn( txn, "current_state_events", @@ -342,6 +355,18 @@ class EventsStore(SQLBaseStore): ) if is_new_state and not context.rejected: + txn.call_after( + self.get_current_state_for_key.invalidate, + event.room_id, event.type, event.state_key + ) + + if (event.type == EventTypes.Name + or event.type == EventTypes.Aliases): + txn.call_after( + self.get_room_name_and_aliases.invalidate, + event.room_id + ) + self._simple_upsert_txn( txn, "current_state_events", @@ -396,3 +421,407 @@ class EventsStore(SQLBaseStore): return self.runInteraction( "have_events", f, ) + + @defer.inlineCallbacks + def _get_events(self, event_ids, check_redacted=True, + get_prev_content=False, allow_rejected=False): + if not event_ids: + defer.returnValue([]) + + event_map = self._get_events_from_cache( + event_ids, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + + missing_events_ids = [e for e in event_ids if e not in event_map] + + if not missing_events_ids: + defer.returnValue([ + event_map[e_id] for e_id in event_ids + if e_id in event_map and event_map[e_id] + ]) + + missing_events = yield self._enqueue_events( + missing_events_ids, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + + event_map.update(missing_events) + + defer.returnValue([ + event_map[e_id] for e_id in event_ids + if e_id in event_map and event_map[e_id] + ]) + + def _get_events_txn(self, txn, event_ids, check_redacted=True, + get_prev_content=False, allow_rejected=False): + if not event_ids: + return [] + + event_map = self._get_events_from_cache( + event_ids, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + + missing_events_ids = [e for e in event_ids if e not in event_map] + + if not missing_events_ids: + return [ + event_map[e_id] for e_id in event_ids + if e_id in event_map and event_map[e_id] + ] + + missing_events = self._fetch_events_txn( + txn, + missing_events_ids, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + + event_map.update(missing_events) + + return [ + event_map[e_id] for e_id in event_ids + if e_id in event_map and event_map[e_id] + ] + + def _invalidate_get_event_cache(self, event_id): + for check_redacted in (False, True): + for get_prev_content in (False, True): + self._get_event_cache.invalidate(event_id, check_redacted, + get_prev_content) + + def _get_event_txn(self, txn, event_id, check_redacted=True, + get_prev_content=False, allow_rejected=False): + + events = self._get_events_txn( + txn, [event_id], + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + + return events[0] if events else None + + def _get_events_from_cache(self, events, check_redacted, get_prev_content, + allow_rejected): + event_map = {} + + for event_id in events: + try: + ret = self._get_event_cache.get( + event_id, check_redacted, get_prev_content + ) + + if allow_rejected or not ret.rejected_reason: + event_map[event_id] = ret + else: + event_map[event_id] = None + except KeyError: + pass + + return event_map + + def _do_fetch(self, conn): + """Takes a database connection and waits for requests for events from + the _event_fetch_list queue. + """ + event_list = [] + i = 0 + while True: + try: + with self._event_fetch_lock: + event_list = self._event_fetch_list + self._event_fetch_list = [] + + if not event_list: + single_threaded = self.database_engine.single_threaded + if single_threaded or i > EVENT_QUEUE_ITERATIONS: + self._event_fetch_ongoing -= 1 + return + else: + self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S) + i += 1 + continue + i = 0 + + event_id_lists = zip(*event_list)[0] + event_ids = [ + item for sublist in event_id_lists for item in sublist + ] + + rows = self._new_transaction( + conn, "do_fetch", [], self._fetch_event_rows, event_ids + ) + + row_dict = { + r["event_id"]: r + for r in rows + } + + # We only want to resolve deferreds from the main thread + def fire(lst, res): + for ids, d in lst: + if not d.called: + try: + d.callback([ + res[i] + for i in ids + if i in res + ]) + except: + logger.exception("Failed to callback") + reactor.callFromThread(fire, event_list, row_dict) + except Exception as e: + logger.exception("do_fetch") + + # We only want to resolve deferreds from the main thread + def fire(evs): + for _, d in evs: + if not d.called: + d.errback(e) + + if event_list: + reactor.callFromThread(fire, event_list) + + @defer.inlineCallbacks + def _enqueue_events(self, events, check_redacted=True, + get_prev_content=False, allow_rejected=False): + """Fetches events from the database using the _event_fetch_list. This + allows batch and bulk fetching of events - it allows us to fetch events + without having to create a new transaction for each request for events. + """ + if not events: + defer.returnValue({}) + + events_d = defer.Deferred() + with self._event_fetch_lock: + self._event_fetch_list.append( + (events, events_d) + ) + + self._event_fetch_lock.notify() + + if self._event_fetch_ongoing < EVENT_QUEUE_THREADS: + self._event_fetch_ongoing += 1 + should_start = True + else: + should_start = False + + if should_start: + self.runWithConnection( + self._do_fetch + ) + + rows = yield preserve_context_over_deferred(events_d) + + if not allow_rejected: + rows[:] = [r for r in rows if not r["rejects"]] + + res = yield defer.gatherResults( + [ + self._get_event_from_row( + row["internal_metadata"], row["json"], row["redacts"], + check_redacted=check_redacted, + get_prev_content=get_prev_content, + rejected_reason=row["rejects"], + ) + for row in rows + ], + consumeErrors=True + ) + + defer.returnValue({ + e.event_id: e + for e in res if e + }) + + def _fetch_event_rows(self, txn, events): + rows = [] + N = 200 + for i in range(1 + len(events) / N): + evs = events[i*N:(i + 1)*N] + if not evs: + break + + sql = ( + "SELECT " + " e.event_id as event_id, " + " e.internal_metadata," + " e.json," + " r.redacts as redacts," + " rej.event_id as rejects " + " FROM event_json as e" + " LEFT JOIN rejections as rej USING (event_id)" + " LEFT JOIN redactions as r ON e.event_id = r.redacts" + " WHERE e.event_id IN (%s)" + ) % (",".join(["?"]*len(evs)),) + + txn.execute(sql, evs) + rows.extend(self.cursor_to_dict(txn)) + + return rows + + def _fetch_events_txn(self, txn, events, check_redacted=True, + get_prev_content=False, allow_rejected=False): + if not events: + return {} + + rows = self._fetch_event_rows( + txn, events, + ) + + if not allow_rejected: + rows[:] = [r for r in rows if not r["rejects"]] + + res = [ + self._get_event_from_row_txn( + txn, + row["internal_metadata"], row["json"], row["redacts"], + check_redacted=check_redacted, + get_prev_content=get_prev_content, + rejected_reason=row["rejects"], + ) + for row in rows + ] + + return { + r.event_id: r + for r in res + } + + @defer.inlineCallbacks + def _get_event_from_row(self, internal_metadata, js, redacted, + check_redacted=True, get_prev_content=False, + rejected_reason=None): + d = json.loads(js) + internal_metadata = json.loads(internal_metadata) + + if rejected_reason: + rejected_reason = yield self._simple_select_one_onecol( + table="rejections", + keyvalues={"event_id": rejected_reason}, + retcol="reason", + desc="_get_event_from_row", + ) + + ev = FrozenEvent( + d, + internal_metadata_dict=internal_metadata, + rejected_reason=rejected_reason, + ) + + if check_redacted and redacted: + ev = prune_event(ev) + + redaction_id = yield self._simple_select_one_onecol( + table="redactions", + keyvalues={"redacts": ev.event_id}, + retcol="event_id", + desc="_get_event_from_row", + ) + + ev.unsigned["redacted_by"] = redaction_id + # Get the redaction event. + + because = yield self.get_event( + redaction_id, + check_redacted=False + ) + + if because: + ev.unsigned["redacted_because"] = because + + if get_prev_content and "replaces_state" in ev.unsigned: + prev = yield self.get_event( + ev.unsigned["replaces_state"], + get_prev_content=False, + ) + if prev: + ev.unsigned["prev_content"] = prev.get_dict()["content"] + + self._get_event_cache.prefill( + ev.event_id, check_redacted, get_prev_content, ev + ) + + defer.returnValue(ev) + + def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted, + check_redacted=True, get_prev_content=False, + rejected_reason=None): + d = json.loads(js) + internal_metadata = json.loads(internal_metadata) + + if rejected_reason: + rejected_reason = self._simple_select_one_onecol_txn( + txn, + table="rejections", + keyvalues={"event_id": rejected_reason}, + retcol="reason", + ) + + ev = FrozenEvent( + d, + internal_metadata_dict=internal_metadata, + rejected_reason=rejected_reason, + ) + + if check_redacted and redacted: + ev = prune_event(ev) + + redaction_id = self._simple_select_one_onecol_txn( + txn, + table="redactions", + keyvalues={"redacts": ev.event_id}, + retcol="event_id", + ) + + ev.unsigned["redacted_by"] = redaction_id + # Get the redaction event. + + because = self._get_event_txn( + txn, + redaction_id, + check_redacted=False + ) + + if because: + ev.unsigned["redacted_because"] = because + + if get_prev_content and "replaces_state" in ev.unsigned: + prev = self._get_event_txn( + txn, + ev.unsigned["replaces_state"], + get_prev_content=False, + ) + if prev: + ev.unsigned["prev_content"] = prev.get_dict()["content"] + + self._get_event_cache.prefill( + ev.event_id, check_redacted, get_prev_content, ev + ) + + return ev + + def _parse_events(self, rows): + return self.runInteraction( + "_parse_events", self._parse_events_txn, rows + ) + + def _parse_events_txn(self, txn, rows): + event_ids = [r["event_id"] for r in rows] + + return self._get_events_txn(txn, event_ids) + + def _has_been_redacted_txn(self, txn, event): + sql = "SELECT event_id FROM redactions WHERE redacts = ?" + txn.execute(sql, (event.event_id,)) + result = txn.fetchone() + return result[0] if result else None |