diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 38395c66ab..626a5eaf6e 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -23,6 +23,7 @@ from synapse.crypto.event_signing import compute_event_reference_hash
from syutil.base64util import decode_base64
from syutil.jsonutil import encode_canonical_json
+from contextlib import contextmanager
import logging
@@ -41,17 +42,25 @@ class EventsStore(SQLBaseStore):
self.min_token -= 1
stream_ordering = self.min_token
+ if stream_ordering is None:
+ stream_ordering_manager = yield self._stream_id_gen.get_next(self)
+ else:
+ @contextmanager
+ def stream_ordering_manager():
+ yield stream_ordering
+
try:
- yield self.runInteraction(
- "persist_event",
- self._persist_event_txn,
- event=event,
- context=context,
- backfilled=backfilled,
- stream_ordering=stream_ordering,
- is_new_state=is_new_state,
- current_state=current_state,
- )
+ with stream_ordering_manager as stream_ordering:
+ yield self.runInteraction(
+ "persist_event",
+ self._persist_event_txn,
+ event=event,
+ context=context,
+ backfilled=backfilled,
+ stream_ordering=stream_ordering,
+ is_new_state=is_new_state,
+ current_state=current_state,
+ )
except _RollbackButIsFineException:
pass
@@ -95,15 +104,6 @@ class EventsStore(SQLBaseStore):
# Remove the any existing cache entries for the event_id
txn.call_after(self._invalidate_get_event_cache, event.event_id)
- if stream_ordering is None:
- with self._stream_id_gen.get_next_txn(txn) as stream_ordering:
- return self._persist_event_txn(
- txn, event, context, backfilled,
- stream_ordering=stream_ordering,
- is_new_state=is_new_state,
- current_state=current_state,
- )
-
# We purposefully do this first since if we include a `current_state`
# key, we *want* to update the `current_state_events` table
if current_state:
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index e40eb8a8c4..89d1643f10 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -78,14 +78,18 @@ class StreamIdGenerator(object):
self._current_max = None
self._unfinished_ids = deque()
- def get_next_txn(self, txn):
+ @defer.inlineCallbacks
+ def get_next(self, store):
"""
Usage:
- with stream_id_gen.get_next_txn(txn) as stream_id:
+ with yield stream_id_gen.get_next as stream_id:
# ... persist event ...
"""
if not self._current_max:
- self._get_or_compute_current_max(txn)
+ yield store.runInteraction(
+ "_compute_current_max",
+ self._get_or_compute_current_max,
+ )
with self._lock:
self._current_max += 1
@@ -101,7 +105,7 @@ class StreamIdGenerator(object):
with self._lock:
self._unfinished_ids.remove(next_id)
- return manager()
+ defer.returnValue(manager())
@defer.inlineCallbacks
def get_max_token(self, store):
|