diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index f172c2690a..7a09c33613 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -119,25 +119,15 @@ class DataStore(RoomMemberStore, RoomStore,
@defer.inlineCallbacks
def get_event(self, event_id, allow_none=False):
- events_dict = yield self._simple_select_one(
- "events",
- {"event_id": event_id},
- [
- "event_id",
- "type",
- "room_id",
- "content",
- "unrecognized_keys",
- "depth",
- ],
- allow_none=allow_none,
- )
+ events = yield self._get_events([event_id])
- if not events_dict:
- defer.returnValue(None)
+ if not events:
+ if allow_none:
+ defer.returnValue(None)
+ else:
+ raise RuntimeError("Could not find event %s" % (event_id,))
- event = yield self._parse_events([events_dict])
- defer.returnValue(event[0])
+ defer.returnValue(events[0])
@log_function
def _persist_event_txn(self, txn, event, context, backfilled,
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index b0f454b90b..72f88cb2aa 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -461,32 +461,18 @@ class SQLBaseStore(object):
**d
)
- def _get_events_txn(self, txn, event_ids):
- # FIXME (erikj): This should be batched?
-
- sql = "SELECT * FROM events WHERE event_id = ? ORDER BY rowid asc"
-
- event_rows = []
- for e_id in event_ids:
- c = txn.execute(sql, (e_id,))
- event_rows.extend(self.cursor_to_dict(c))
-
- return self._parse_events_txn(txn, event_rows)
-
- def _parse_events(self, rows):
+ def _get_events(self, event_ids):
return self.runInteraction(
- "_parse_events", self._parse_events_txn, rows
+ "_get_events", self._get_events_txn, event_ids
)
- def _parse_events_txn(self, txn, rows):
- event_ids = [r["event_id"] for r in rows]
-
+ def _get_events_txn(self, txn, event_ids):
events = []
- for event_id in event_ids:
+ for e_id in event_ids:
js = self._simple_select_one_onecol_txn(
txn,
table="event_json",
- keyvalues={"event_id": event_id},
+ keyvalues={"event_id": e_id},
retcol="json",
allow_none=True,
)
@@ -516,6 +502,16 @@ class SQLBaseStore(object):
return events
+ def _parse_events(self, rows):
+ return self.runInteraction(
+ "_parse_events", self._parse_events_txn, rows
+ )
+
+ def _parse_events_txn(self, txn, rows):
+ event_ids = [r["event_id"] for r in rows]
+
+ return self._get_events_txn(txn, event_ids)
+
def _has_been_redacted_txn(self, txn, event):
sql = "SELECT event_id FROM redactions WHERE redacts = ?"
txn.execute(sql, (event.event_id,))
|