diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 1304219e86..d2a010bd88 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -15,20 +15,36 @@
from _base import SQLBaseStore, _RollbackButIsFineException
-from twisted.internet import defer
+from twisted.internet import defer, reactor
+from synapse.events import FrozenEvent
+from synapse.events.utils import prune_event
+
+from synapse.util.logcontext import preserve_context_over_deferred
from synapse.util.logutils import log_function
from synapse.api.constants import EventTypes
from synapse.crypto.event_signing import compute_event_reference_hash
from syutil.base64util import decode_base64
from syutil.jsonutil import encode_canonical_json
+from contextlib import contextmanager
import logging
+import simplejson as json
logger = logging.getLogger(__name__)
+# These values are used in the `enqueus_event` and `_do_fetch` methods to
+# control how we batch/bulk fetch events from the database.
+# The values are plucked out of thing air to make initial sync run faster
+# on jki.re
+# TODO: Make these configurable.
+EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events
+EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events
+EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
+
+
class EventsStore(SQLBaseStore):
@defer.inlineCallbacks
@log_function
@@ -41,20 +57,32 @@ class EventsStore(SQLBaseStore):
self.min_token -= 1
stream_ordering = self.min_token
+ if stream_ordering is None:
+ stream_ordering_manager = yield self._stream_id_gen.get_next(self)
+ else:
+ @contextmanager
+ def stream_ordering_manager():
+ yield stream_ordering
+ stream_ordering_manager = stream_ordering_manager()
+
try:
- yield self.runInteraction(
- "persist_event",
- self._persist_event_txn,
- event=event,
- context=context,
- backfilled=backfilled,
- stream_ordering=stream_ordering,
- is_new_state=is_new_state,
- current_state=current_state,
- )
+ with stream_ordering_manager as stream_ordering:
+ yield self.runInteraction(
+ "persist_event",
+ self._persist_event_txn,
+ event=event,
+ context=context,
+ backfilled=backfilled,
+ stream_ordering=stream_ordering,
+ is_new_state=is_new_state,
+ current_state=current_state,
+ )
except _RollbackButIsFineException:
pass
+ max_persisted_id = yield self._stream_id_gen.get_max_token(self)
+ defer.returnValue((stream_ordering, max_persisted_id))
+
@defer.inlineCallbacks
def get_event(self, event_id, check_redacted=True,
get_prev_content=False, allow_rejected=False,
@@ -74,18 +102,17 @@ class EventsStore(SQLBaseStore):
Returns:
Deferred : A FrozenEvent.
"""
- event = yield self.runInteraction(
- "get_event", self._get_event_txn,
- event_id,
+ events = yield self._get_events(
+ [event_id],
check_redacted=check_redacted,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
- if not event and not allow_none:
+ if not events and not allow_none:
raise RuntimeError("Could not find event %s" % (event_id,))
- defer.returnValue(event)
+ defer.returnValue(events[0] if events else None)
@log_function
def _persist_event_txn(self, txn, event, context, backfilled,
@@ -95,15 +122,6 @@ class EventsStore(SQLBaseStore):
# Remove the any existing cache entries for the event_id
txn.call_after(self._invalidate_get_event_cache, event.event_id)
- if stream_ordering is None:
- with self._stream_id_gen.get_next_txn(txn) as stream_ordering:
- return self._persist_event_txn(
- txn, event, context, backfilled,
- stream_ordering=stream_ordering,
- is_new_state=is_new_state,
- current_state=current_state,
- )
-
# We purposefully do this first since if we include a `current_state`
# key, we *want* to update the `current_state_events` table
if current_state:
@@ -134,19 +152,17 @@ class EventsStore(SQLBaseStore):
outlier = event.internal_metadata.is_outlier()
if not outlier:
- self._store_state_groups_txn(txn, event, context)
-
self._update_min_depth_for_room_txn(
txn,
event.room_id,
event.depth
)
- have_persisted = self._simple_select_one_onecol_txn(
+ have_persisted = self._simple_select_one_txn(
txn,
- table="event_json",
+ table="events",
keyvalues={"event_id": event.event_id},
- retcol="event_id",
+ retcols=["event_id", "outlier"],
allow_none=True,
)
@@ -161,7 +177,9 @@ class EventsStore(SQLBaseStore):
# if we are persisting an event that we had persisted as an outlier,
# but is no longer one.
if have_persisted:
- if not outlier:
+ if not outlier and have_persisted["outlier"]:
+ self._store_state_groups_txn(txn, event, context)
+
sql = (
"UPDATE event_json SET internal_metadata = ?"
" WHERE event_id = ?"
@@ -181,6 +199,9 @@ class EventsStore(SQLBaseStore):
)
return
+ if not outlier:
+ self._store_state_groups_txn(txn, event, context)
+
self._handle_prev_events(
txn,
outlier=outlier,
@@ -400,3 +421,407 @@ class EventsStore(SQLBaseStore):
return self.runInteraction(
"have_events", f,
)
+
+ @defer.inlineCallbacks
+ def _get_events(self, event_ids, check_redacted=True,
+ get_prev_content=False, allow_rejected=False):
+ if not event_ids:
+ defer.returnValue([])
+
+ event_map = self._get_events_from_cache(
+ event_ids,
+ check_redacted=check_redacted,
+ get_prev_content=get_prev_content,
+ allow_rejected=allow_rejected,
+ )
+
+ missing_events_ids = [e for e in event_ids if e not in event_map]
+
+ if not missing_events_ids:
+ defer.returnValue([
+ event_map[e_id] for e_id in event_ids
+ if e_id in event_map and event_map[e_id]
+ ])
+
+ missing_events = yield self._enqueue_events(
+ missing_events_ids,
+ check_redacted=check_redacted,
+ get_prev_content=get_prev_content,
+ allow_rejected=allow_rejected,
+ )
+
+ event_map.update(missing_events)
+
+ defer.returnValue([
+ event_map[e_id] for e_id in event_ids
+ if e_id in event_map and event_map[e_id]
+ ])
+
+ def _get_events_txn(self, txn, event_ids, check_redacted=True,
+ get_prev_content=False, allow_rejected=False):
+ if not event_ids:
+ return []
+
+ event_map = self._get_events_from_cache(
+ event_ids,
+ check_redacted=check_redacted,
+ get_prev_content=get_prev_content,
+ allow_rejected=allow_rejected,
+ )
+
+ missing_events_ids = [e for e in event_ids if e not in event_map]
+
+ if not missing_events_ids:
+ return [
+ event_map[e_id] for e_id in event_ids
+ if e_id in event_map and event_map[e_id]
+ ]
+
+ missing_events = self._fetch_events_txn(
+ txn,
+ missing_events_ids,
+ check_redacted=check_redacted,
+ get_prev_content=get_prev_content,
+ allow_rejected=allow_rejected,
+ )
+
+ event_map.update(missing_events)
+
+ return [
+ event_map[e_id] for e_id in event_ids
+ if e_id in event_map and event_map[e_id]
+ ]
+
+ def _invalidate_get_event_cache(self, event_id):
+ for check_redacted in (False, True):
+ for get_prev_content in (False, True):
+ self._get_event_cache.invalidate(event_id, check_redacted,
+ get_prev_content)
+
+ def _get_event_txn(self, txn, event_id, check_redacted=True,
+ get_prev_content=False, allow_rejected=False):
+
+ events = self._get_events_txn(
+ txn, [event_id],
+ check_redacted=check_redacted,
+ get_prev_content=get_prev_content,
+ allow_rejected=allow_rejected,
+ )
+
+ return events[0] if events else None
+
+ def _get_events_from_cache(self, events, check_redacted, get_prev_content,
+ allow_rejected):
+ event_map = {}
+
+ for event_id in events:
+ try:
+ ret = self._get_event_cache.get(
+ event_id, check_redacted, get_prev_content
+ )
+
+ if allow_rejected or not ret.rejected_reason:
+ event_map[event_id] = ret
+ else:
+ event_map[event_id] = None
+ except KeyError:
+ pass
+
+ return event_map
+
+ def _do_fetch(self, conn):
+ """Takes a database connection and waits for requests for events from
+ the _event_fetch_list queue.
+ """
+ event_list = []
+ i = 0
+ while True:
+ try:
+ with self._event_fetch_lock:
+ event_list = self._event_fetch_list
+ self._event_fetch_list = []
+
+ if not event_list:
+ single_threaded = self.database_engine.single_threaded
+ if single_threaded or i > EVENT_QUEUE_ITERATIONS:
+ self._event_fetch_ongoing -= 1
+ return
+ else:
+ self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
+ i += 1
+ continue
+ i = 0
+
+ event_id_lists = zip(*event_list)[0]
+ event_ids = [
+ item for sublist in event_id_lists for item in sublist
+ ]
+
+ rows = self._new_transaction(
+ conn, "do_fetch", [], self._fetch_event_rows, event_ids
+ )
+
+ row_dict = {
+ r["event_id"]: r
+ for r in rows
+ }
+
+ # We only want to resolve deferreds from the main thread
+ def fire(lst, res):
+ for ids, d in lst:
+ if not d.called:
+ try:
+ d.callback([
+ res[i]
+ for i in ids
+ if i in res
+ ])
+ except:
+ logger.exception("Failed to callback")
+ reactor.callFromThread(fire, event_list, row_dict)
+ except Exception as e:
+ logger.exception("do_fetch")
+
+ # We only want to resolve deferreds from the main thread
+ def fire(evs):
+ for _, d in evs:
+ if not d.called:
+ d.errback(e)
+
+ if event_list:
+ reactor.callFromThread(fire, event_list)
+
+ @defer.inlineCallbacks
+ def _enqueue_events(self, events, check_redacted=True,
+ get_prev_content=False, allow_rejected=False):
+ """Fetches events from the database using the _event_fetch_list. This
+ allows batch and bulk fetching of events - it allows us to fetch events
+ without having to create a new transaction for each request for events.
+ """
+ if not events:
+ defer.returnValue({})
+
+ events_d = defer.Deferred()
+ with self._event_fetch_lock:
+ self._event_fetch_list.append(
+ (events, events_d)
+ )
+
+ self._event_fetch_lock.notify()
+
+ if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
+ self._event_fetch_ongoing += 1
+ should_start = True
+ else:
+ should_start = False
+
+ if should_start:
+ self.runWithConnection(
+ self._do_fetch
+ )
+
+ rows = yield preserve_context_over_deferred(events_d)
+
+ if not allow_rejected:
+ rows[:] = [r for r in rows if not r["rejects"]]
+
+ res = yield defer.gatherResults(
+ [
+ self._get_event_from_row(
+ row["internal_metadata"], row["json"], row["redacts"],
+ check_redacted=check_redacted,
+ get_prev_content=get_prev_content,
+ rejected_reason=row["rejects"],
+ )
+ for row in rows
+ ],
+ consumeErrors=True
+ )
+
+ defer.returnValue({
+ e.event_id: e
+ for e in res if e
+ })
+
+ def _fetch_event_rows(self, txn, events):
+ rows = []
+ N = 200
+ for i in range(1 + len(events) / N):
+ evs = events[i*N:(i + 1)*N]
+ if not evs:
+ break
+
+ sql = (
+ "SELECT "
+ " e.event_id as event_id, "
+ " e.internal_metadata,"
+ " e.json,"
+ " r.redacts as redacts,"
+ " rej.event_id as rejects "
+ " FROM event_json as e"
+ " LEFT JOIN rejections as rej USING (event_id)"
+ " LEFT JOIN redactions as r ON e.event_id = r.redacts"
+ " WHERE e.event_id IN (%s)"
+ ) % (",".join(["?"]*len(evs)),)
+
+ txn.execute(sql, evs)
+ rows.extend(self.cursor_to_dict(txn))
+
+ return rows
+
+ def _fetch_events_txn(self, txn, events, check_redacted=True,
+ get_prev_content=False, allow_rejected=False):
+ if not events:
+ return {}
+
+ rows = self._fetch_event_rows(
+ txn, events,
+ )
+
+ if not allow_rejected:
+ rows[:] = [r for r in rows if not r["rejects"]]
+
+ res = [
+ self._get_event_from_row_txn(
+ txn,
+ row["internal_metadata"], row["json"], row["redacts"],
+ check_redacted=check_redacted,
+ get_prev_content=get_prev_content,
+ rejected_reason=row["rejects"],
+ )
+ for row in rows
+ ]
+
+ return {
+ r.event_id: r
+ for r in res
+ }
+
+ @defer.inlineCallbacks
+ def _get_event_from_row(self, internal_metadata, js, redacted,
+ check_redacted=True, get_prev_content=False,
+ rejected_reason=None):
+ d = json.loads(js)
+ internal_metadata = json.loads(internal_metadata)
+
+ if rejected_reason:
+ rejected_reason = yield self._simple_select_one_onecol(
+ table="rejections",
+ keyvalues={"event_id": rejected_reason},
+ retcol="reason",
+ desc="_get_event_from_row",
+ )
+
+ ev = FrozenEvent(
+ d,
+ internal_metadata_dict=internal_metadata,
+ rejected_reason=rejected_reason,
+ )
+
+ if check_redacted and redacted:
+ ev = prune_event(ev)
+
+ redaction_id = yield self._simple_select_one_onecol(
+ table="redactions",
+ keyvalues={"redacts": ev.event_id},
+ retcol="event_id",
+ desc="_get_event_from_row",
+ )
+
+ ev.unsigned["redacted_by"] = redaction_id
+ # Get the redaction event.
+
+ because = yield self.get_event(
+ redaction_id,
+ check_redacted=False
+ )
+
+ if because:
+ ev.unsigned["redacted_because"] = because
+
+ if get_prev_content and "replaces_state" in ev.unsigned:
+ prev = yield self.get_event(
+ ev.unsigned["replaces_state"],
+ get_prev_content=False,
+ )
+ if prev:
+ ev.unsigned["prev_content"] = prev.get_dict()["content"]
+
+ self._get_event_cache.prefill(
+ ev.event_id, check_redacted, get_prev_content, ev
+ )
+
+ defer.returnValue(ev)
+
+ def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted,
+ check_redacted=True, get_prev_content=False,
+ rejected_reason=None):
+ d = json.loads(js)
+ internal_metadata = json.loads(internal_metadata)
+
+ if rejected_reason:
+ rejected_reason = self._simple_select_one_onecol_txn(
+ txn,
+ table="rejections",
+ keyvalues={"event_id": rejected_reason},
+ retcol="reason",
+ )
+
+ ev = FrozenEvent(
+ d,
+ internal_metadata_dict=internal_metadata,
+ rejected_reason=rejected_reason,
+ )
+
+ if check_redacted and redacted:
+ ev = prune_event(ev)
+
+ redaction_id = self._simple_select_one_onecol_txn(
+ txn,
+ table="redactions",
+ keyvalues={"redacts": ev.event_id},
+ retcol="event_id",
+ )
+
+ ev.unsigned["redacted_by"] = redaction_id
+ # Get the redaction event.
+
+ because = self._get_event_txn(
+ txn,
+ redaction_id,
+ check_redacted=False
+ )
+
+ if because:
+ ev.unsigned["redacted_because"] = because
+
+ if get_prev_content and "replaces_state" in ev.unsigned:
+ prev = self._get_event_txn(
+ txn,
+ ev.unsigned["replaces_state"],
+ get_prev_content=False,
+ )
+ if prev:
+ ev.unsigned["prev_content"] = prev.get_dict()["content"]
+
+ self._get_event_cache.prefill(
+ ev.event_id, check_redacted, get_prev_content, ev
+ )
+
+ return ev
+
+ def _parse_events(self, rows):
+ return self.runInteraction(
+ "_parse_events", self._parse_events_txn, rows
+ )
+
+ def _parse_events_txn(self, txn, rows):
+ event_ids = [r["event_id"] for r in rows]
+
+ return self._get_events_txn(txn, event_ids)
+
+ def _has_been_redacted_txn(self, txn, event):
+ sql = "SELECT event_id FROM redactions WHERE redacts = ?"
+ txn.execute(sql, (event.event_id,))
+ result = txn.fetchone()
+ return result[0] if result else None
|