diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index a6fc4d6ea4..f7b4def9ec 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -875,19 +875,11 @@ class SQLBaseStore(object):
def _get_events_txn(self, txn, event_ids, check_redacted=True,
get_prev_content=False):
- if not event_ids:
- return []
-
- events = [
- self._get_event_txn(
- txn, event_id,
- check_redacted=check_redacted,
- get_prev_content=get_prev_content
- )
- for event_id in event_ids
- ]
-
- return [e for e in events if e]
+ return self._fetch_events_txn(
+ txn, event_ids,
+ check_redacted=check_redacted,
+ get_prev_content=get_prev_content,
+ )
def _invalidate_get_event_cache(self, event_id):
for check_redacted in (False, True):
@@ -950,6 +942,63 @@ class SQLBaseStore(object):
else:
return None
+ def _fetch_events_txn(self, txn, events, check_redacted=True,
+ get_prev_content=False, allow_rejected=False):
+ if not events:
+ return []
+
+ event_map = {}
+
+ for event_id in events:
+ try:
+ ret = self._get_event_cache.get(event_id, check_redacted, get_prev_content)
+
+ if allow_rejected or not ret.rejected_reason:
+ event_map[event_id] = ret
+ else:
+ return None
+ except KeyError:
+ pass
+
+ missing_events = [
+ e for e in events
+ if e not in event_map
+ ]
+
+ if missing_events:
+ sql = (
+ "SELECT e.internal_metadata, e.json, r.redacts, rej.event_id "
+ " FROM event_json as e"
+ " LEFT JOIN rejections as rej USING (event_id)"
+ " LEFT JOIN redactions as r ON e.event_id = r.redacts"
+ " WHERE e.event_id IN (%s)"
+ ) % (",".join(["?"]*len(missing_events)),)
+
+ txn.execute(sql, missing_events)
+ rows = txn.fetchall()
+
+ res = [
+ self._get_event_from_row_txn(
+ txn, row[0], row[1], row[2],
+ check_redacted=check_redacted,
+ get_prev_content=get_prev_content,
+ rejected_reason=row[3],
+ )
+ for row in rows
+ ]
+
+ event_map.update({
+ e.event_id: e
+ for e in res if e
+ })
+
+ for e in res:
+ self._get_event_cache.prefill(
+ e.event_id, check_redacted, get_prev_content, e
+ )
+
+ return [event_map[e_id] for e_id in events if e_id in event_map]
+
def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted,
check_redacted=True, get_prev_content=False,
rejected_reason=None):
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 6d0ecf8dd9..a80a947436 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -83,31 +83,11 @@ class StateStore(SQLBaseStore):
f,
)
- def fetch_events(txn, events):
- sql = (
- "SELECT e.internal_metadata, e.json, r.redacts, rej.event_id "
- " FROM event_json as e"
- " LEFT JOIN rejections as rej USING (event_id)"
- " LEFT JOIN redactions as r ON e.event_id = r.redacts"
- " WHERE e.event_id IN (%s)"
- ) % (",".join(["?"]*len(events)),)
-
- txn.execute(sql, events)
- rows = txn.fetchall()
-
- return [
- self._get_event_from_row_txn(
- txn, row[0], row[1], row[2],
- rejected_reason=row[3],
- )
- for row in rows
- ]
-
@defer.inlineCallbacks
def c(vals):
vals[:] = yield self.runInteraction(
"_get_state_groups_ev",
- fetch_events, vals
+ self._fetch_events_txn, vals
)
yield defer.gatherResults(
|