diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index a8f8989e38..c20ff3a572 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -299,6 +299,10 @@ class SQLBaseStore(object):
self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
max_entries=hs.config.event_cache_size)
+ self._event_fetch_lock = threading.Lock()
+ self._event_fetch_list = []
+ self._event_fetch_ongoing = False
+
self.database_engine = hs.database_engine
self._stream_id_gen = StreamIdGenerator()
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 0aa4e0d445..be88328ce5 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -15,7 +15,7 @@
from _base import SQLBaseStore, _RollbackButIsFineException
-from twisted.internet import defer
+from twisted.internet import defer, reactor
from synapse.events import FrozenEvent
from synapse.events.utils import prune_event
@@ -89,18 +89,17 @@ class EventsStore(SQLBaseStore):
Returns:
Deferred : A FrozenEvent.
"""
- event = yield self.runInteraction(
- "get_event", self._get_event_txn,
- event_id,
- check_redacted=check_redacted,
- get_prev_content=get_prev_content,
- allow_rejected=allow_rejected,
+ events = yield self._get_events(
+ [event_id],
+ check_redacted=True,
+ get_prev_content=False,
+ allow_rejected=False,
)
- if not event and not allow_none:
+ if not events and not allow_none:
raise RuntimeError("Could not find event %s" % (event_id,))
- defer.returnValue(event)
+ defer.returnValue(events[0] if events else None)
@log_function
def _persist_event_txn(self, txn, event, context, backfilled,
@@ -420,13 +419,21 @@ class EventsStore(SQLBaseStore):
if e_id in event_map and event_map[e_id]
])
- missing_events = yield self._fetch_events(
- txn,
- missing_events_ids,
- check_redacted=check_redacted,
- get_prev_content=get_prev_content,
- allow_rejected=allow_rejected,
- )
+ if not txn:
+ missing_events = yield self._enqueue_events(
+ missing_events_ids,
+ check_redacted=check_redacted,
+ get_prev_content=get_prev_content,
+ allow_rejected=allow_rejected,
+ )
+ else:
+ missing_events = self._fetch_events_txn(
+ txn,
+ missing_events_ids,
+ check_redacted=check_redacted,
+ get_prev_content=get_prev_content,
+ allow_rejected=allow_rejected,
+ )
event_map.update(missing_events)
@@ -492,11 +499,82 @@ class EventsStore(SQLBaseStore):
))
@defer.inlineCallbacks
- def _fetch_events(self, txn, events, check_redacted=True,
- get_prev_content=False, allow_rejected=False):
+ def _enqueue_events(self, events, check_redacted=True,
+ get_prev_content=False, allow_rejected=False):
if not events:
defer.returnValue({})
+ def do_fetch(txn):
+ event_list = []
+ try:
+ with self._event_fetch_lock:
+ event_list = self._event_fetch_list
+ self._event_fetch_list = []
+
+ if not event_list:
+ return
+
+ event_id_lists = zip(*event_list)[0]
+ event_ids = [
+ item for sublist in event_id_lists for item in sublist
+ ]
+ rows = self._fetch_event_rows(txn, event_ids)
+
+ row_dict = {
+ r["event_id"]: r
+ for r in rows
+ }
+
+ for ids, d in event_list:
+ d.callback(
+ [
+ row_dict[i] for i in ids
+ if i in row_dict
+ ]
+ )
+ except Exception as e:
+ for _, d in event_list:
+ try:
+ reactor.callFromThread(d.errback, e)
+ except:
+ pass
+ finally:
+ with self._event_fetch_lock:
+ self._event_fetch_ongoing = False
+
+ def cb(rows):
+ return defer.gatherResults([
+ self._get_event_from_row(
+ None,
+ row["internal_metadata"], row["json"], row["redacts"],
+ check_redacted=check_redacted,
+ get_prev_content=get_prev_content,
+ rejected_reason=row["rejects"],
+ )
+ for row in rows
+ ])
+
+ d = defer.Deferred()
+ d.addCallback(cb)
+ with self._event_fetch_lock:
+ self._event_fetch_list.append(
+ (events, d)
+ )
+
+ if not self._event_fetch_ongoing:
+ self.runInteraction(
+ "do_fetch",
+ do_fetch
+ )
+
+ res = yield d
+
+ defer.returnValue({
+ e.event_id: e
+ for e in res if e
+ })
+
+ def _fetch_event_rows(self, txn, events):
rows = []
N = 200
for i in range(1 + len(events) / N):
@@ -505,43 +583,56 @@ class EventsStore(SQLBaseStore):
break
sql = (
- "SELECT e.internal_metadata, e.json, r.redacts, rej.event_id "
+ "SELECT "
+ " e.event_id as event_id, "
+ " e.internal_metadata,"
+ " e.json,"
+ " r.redacts as redacts,"
+ " rej.event_id as rejects "
" FROM event_json as e"
" LEFT JOIN rejections as rej USING (event_id)"
" LEFT JOIN redactions as r ON e.event_id = r.redacts"
" WHERE e.event_id IN (%s)"
) % (",".join(["?"]*len(evs)),)
- if txn:
- txn.execute(sql, evs)
- rows.extend(txn.fetchall())
- else:
- res = yield self._execute("_fetch_events", None, sql, *evs)
- rows.extend(res)
+ txn.execute(sql, evs)
+ rows.extend(self.cursor_to_dict(txn))
+
+ return rows
+
+ @defer.inlineCallbacks
+ def _fetch_events(self, txn, events, check_redacted=True,
+ get_prev_content=False, allow_rejected=False):
+ if not events:
+ defer.returnValue({})
+
+ if txn:
+ rows = self._fetch_event_rows(
+ txn, events,
+ )
+ else:
+ rows = yield self.runInteraction(
+ self._fetch_event_rows,
+ events,
+ )
res = yield defer.gatherResults(
[
defer.maybeDeferred(
self._get_event_from_row,
txn,
- row[0], row[1], row[2],
+ row["internal_metadata"], row["json"], row["redacts"],
check_redacted=check_redacted,
get_prev_content=get_prev_content,
- rejected_reason=row[3],
+ rejected_reason=row["rejects"],
)
for row in rows
- ],
- consumeErrors=True,
+ ]
)
- for e in res:
- self._get_event_cache.prefill(
- e.event_id, check_redacted, get_prev_content, e
- )
-
defer.returnValue({
- e.event_id: e
- for e in res if e
+ r.event_id: r
+ for r in res
})
@defer.inlineCallbacks
@@ -611,6 +702,10 @@ class EventsStore(SQLBaseStore):
if prev:
ev.unsigned["prev_content"] = prev.get_dict()["content"]
+ self._get_event_cache.prefill(
+ ev.event_id, check_redacted, get_prev_content, ev
+ )
+
defer.returnValue(ev)
def _parse_events(self, rows):
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index b9afb3364d..260714ccc2 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -80,16 +80,16 @@ class Clock(object):
def stop_looping_call(self, loop):
loop.stop()
- def call_later(self, delay, callback):
+ def call_later(self, delay, callback, *args, **kwargs):
current_context = LoggingContext.current_context()
- def wrapped_callback():
+ def wrapped_callback(*args, **kwargs):
with PreserveLoggingContext():
LoggingContext.thread_local.current_context = current_context
- callback()
+ callback(*args, **kwargs)
with PreserveLoggingContext():
- return reactor.callLater(delay, wrapped_callback)
+ return reactor.callLater(delay, wrapped_callback, *args, **kwargs)
def cancel_call_later(self, timer):
timer.cancel()
|