diff options
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r-- | synapse/storage/_base.py | 75 |
1 files changed, 62 insertions, 13 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index a6fc4d6ea4..f7b4def9ec 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -875,19 +875,11 @@ class SQLBaseStore(object): def _get_events_txn(self, txn, event_ids, check_redacted=True, get_prev_content=False): - if not event_ids: - return [] - - events = [ - self._get_event_txn( - txn, event_id, - check_redacted=check_redacted, - get_prev_content=get_prev_content - ) - for event_id in event_ids - ] - - return [e for e in events if e] + return self._fetch_events_txn( + txn, event_ids, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + ) def _invalidate_get_event_cache(self, event_id): for check_redacted in (False, True): @@ -950,6 +942,63 @@ class SQLBaseStore(object): else: return None + def _fetch_events_txn(self, txn, events, check_redacted=True, + get_prev_content=False, allow_rejected=False): + if not events: + return [] + + event_map = {} + + for event_id in events: + try: + ret = self._get_event_cache.get(event_id, check_redacted, get_prev_content) + + if allow_rejected or not ret.rejected_reason: + event_map[event_id] = ret + else: + return None + except KeyError: + pass + + missing_events = [ + e for e in events + if e not in event_map + ] + + if missing_events: + sql = ( + "SELECT e.internal_metadata, e.json, r.redacts, rej.event_id " + " FROM event_json as e" + " LEFT JOIN rejections as rej USING (event_id)" + " LEFT JOIN redactions as r ON e.event_id = r.redacts" + " WHERE e.event_id IN (%s)" + ) % (",".join(["?"]*len(missing_events)),) + + txn.execute(sql, missing_events) + rows = txn.fetchall() + + res = [ + self._get_event_from_row_txn( + txn, row[0], row[1], row[2], + check_redacted=check_redacted, + get_prev_content=get_prev_content, + rejected_reason=row[3], + ) + for row in rows + ] + + event_map.update({ + e.event_id: e + for e in res if e + }) + + for e in res: + self._get_event_cache.prefill( + e.event_id, check_redacted, get_prev_content, e + ) + + return [event_map[e_id] for e_id in events if e_id in event_map] + def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted, check_redacted=True, get_prev_content=False, rejected_reason=None): |