diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 552e7ca35b..5233430028 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -101,30 +101,16 @@ class EventsStore(SQLBaseStore):
@defer.inlineCallbacks
@log_function
- def persist_event(self, event, context, backfilled=False,
+ def persist_event(self, event, context,
is_new_state=True, current_state=None):
- stream_ordering = None
- if backfilled:
- self.min_stream_token -= 1
- stream_ordering = self.min_stream_token
-
- if stream_ordering is None:
- stream_ordering_manager = self._stream_id_gen.get_next()
- else:
- @contextmanager
- def stream_ordering_manager():
- yield stream_ordering
- stream_ordering_manager = stream_ordering_manager()
-
try:
- with stream_ordering_manager as stream_ordering:
+ with self._stream_id_gen.get_next() as stream_ordering:
event.internal_metadata.stream_ordering = stream_ordering
yield self.runInteraction(
"persist_event",
self._persist_event_txn,
event=event,
context=context,
- backfilled=backfilled,
is_new_state=is_new_state,
current_state=current_state,
)
@@ -165,13 +151,38 @@ class EventsStore(SQLBaseStore):
defer.returnValue(events[0] if events else None)
+ @defer.inlineCallbacks
+ def get_events(self, event_ids, check_redacted=True,
+ get_prev_content=False, allow_rejected=False):
+ """Get events from the database
+
+ Args:
+ event_ids (list): The event_ids of the events to fetch
+ check_redacted (bool): If True, check if event has been redacted
+ and redact it.
+ get_prev_content (bool): If True and event is a state event,
+ include the previous states content in the unsigned field.
+ allow_rejected (bool): If True return rejected events.
+
+ Returns:
+ Deferred : Dict from event_id to event.
+ """
+ events = yield self._get_events(
+ event_ids,
+ check_redacted=check_redacted,
+ get_prev_content=get_prev_content,
+ allow_rejected=allow_rejected,
+ )
+
+ defer.returnValue({e.event_id: e for e in events})
+
@log_function
- def _persist_event_txn(self, txn, event, context, backfilled,
+ def _persist_event_txn(self, txn, event, context,
is_new_state=True, current_state=None):
# We purposefully do this first since if we include a `current_state`
# key, we *want* to update the `current_state_events` table
if current_state:
- txn.call_after(self.get_current_state_for_key.invalidate_all)
+ txn.call_after(self._get_current_state_for_key.invalidate_all)
txn.call_after(self.get_rooms_for_user.invalidate_all)
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
@@ -198,7 +209,7 @@ class EventsStore(SQLBaseStore):
return self._persist_events_txn(
txn,
[(event, context)],
- backfilled=backfilled,
+ backfilled=False,
is_new_state=is_new_state,
)
@@ -455,7 +466,7 @@ class EventsStore(SQLBaseStore):
for event, _ in state_events_and_contexts:
if not context.rejected:
txn.call_after(
- self.get_current_state_for_key.invalidate,
+ self._get_current_state_for_key.invalidate,
(event.room_id, event.type, event.state_key,)
)
@@ -526,6 +537,9 @@ class EventsStore(SQLBaseStore):
if not event_ids:
defer.returnValue([])
+ event_id_list = event_ids
+ event_ids = set(event_ids)
+
event_map = self._get_events_from_cache(
event_ids,
check_redacted=check_redacted,
@@ -535,23 +549,18 @@ class EventsStore(SQLBaseStore):
missing_events_ids = [e for e in event_ids if e not in event_map]
- if not missing_events_ids:
- defer.returnValue([
- event_map[e_id] for e_id in event_ids
- if e_id in event_map and event_map[e_id]
- ])
-
- missing_events = yield self._enqueue_events(
- missing_events_ids,
- check_redacted=check_redacted,
- get_prev_content=get_prev_content,
- allow_rejected=allow_rejected,
- )
+ if missing_events_ids:
+ missing_events = yield self._enqueue_events(
+ missing_events_ids,
+ check_redacted=check_redacted,
+ get_prev_content=get_prev_content,
+ allow_rejected=allow_rejected,
+ )
- event_map.update(missing_events)
+ event_map.update(missing_events)
defer.returnValue([
- event_map[e_id] for e_id in event_ids
+ event_map[e_id] for e_id in event_id_list
if e_id in event_map and event_map[e_id]
])
|