| diff --git a/synapse/notifier.py b/synapse/notifier.py
index 0b50898f3f..7b9bae1056 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -316,7 +316,7 @@ class Notifier(object):
                     user, getattr(from_token, keyname), limit,
                 )
                 events.extend(stuff)
-                end_token = from_token.copy_and_replace(keyname, new_key)
+                end_token = end_token.copy_and_replace(keyname, new_key)
 
             if events:
                 defer.returnValue((events, (from_token, end_token)))
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
 index 38395c66ab..626a5eaf6e 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -23,6 +23,7 @@ from synapse.crypto.event_signing import compute_event_reference_hash
 
 from syutil.base64util import decode_base64
 from syutil.jsonutil import encode_canonical_json
+from contextlib import contextmanager
 
 import logging
 
@@ -41,17 +42,25 @@ class EventsStore(SQLBaseStore):
             self.min_token -= 1
             stream_ordering = self.min_token
 
+        if stream_ordering is None:
+            stream_ordering_manager = yield self._stream_id_gen.get_next(self)
+        else:
+            @contextmanager
+            def stream_ordering_manager():
+                yield stream_ordering
+
         try:
-            yield self.runInteraction(
-                "persist_event",
-                self._persist_event_txn,
-                event=event,
-                context=context,
-                backfilled=backfilled,
-                stream_ordering=stream_ordering,
-                is_new_state=is_new_state,
-                current_state=current_state,
-            )
+            with stream_ordering_manager as stream_ordering:
+                yield self.runInteraction(
+                    "persist_event",
+                    self._persist_event_txn,
+                    event=event,
+                    context=context,
+                    backfilled=backfilled,
+                    stream_ordering=stream_ordering,
+                    is_new_state=is_new_state,
+                    current_state=current_state,
+                )
         except _RollbackButIsFineException:
             pass
 
@@ -95,15 +104,6 @@ class EventsStore(SQLBaseStore):
         # Remove the any existing cache entries for the event_id
         txn.call_after(self._invalidate_get_event_cache, event.event_id)
 
-        if stream_ordering is None:
-            with self._stream_id_gen.get_next_txn(txn) as stream_ordering:
-                return self._persist_event_txn(
-                    txn, event, context, backfilled,
-                    stream_ordering=stream_ordering,
-                    is_new_state=is_new_state,
-                    current_state=current_state,
-                )
-
         # We purposefully do this first since if we include a `current_state`
         # key, we *want* to update the `current_state_events` table
         if current_state:
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
 index e40eb8a8c4..89d1643f10 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -78,14 +78,18 @@ class StreamIdGenerator(object):
         self._current_max = None
         self._unfinished_ids = deque()
 
-    def get_next_txn(self, txn):
+    @defer.inlineCallbacks
+    def get_next(self, store):
         """
         Usage:
-            with stream_id_gen.get_next_txn(txn) as stream_id:
+            with yield stream_id_gen.get_next as stream_id:
                 # ... persist event ...
         """
         if not self._current_max:
-            self._get_or_compute_current_max(txn)
+            yield store.runInteraction(
+                "_compute_current_max",
+                self._get_or_compute_current_max,
+            )
 
         with self._lock:
             self._current_max += 1
@@ -101,7 +105,7 @@ class StreamIdGenerator(object):
                 with self._lock:
                     self._unfinished_ids.remove(next_id)
 
-        return manager()
+        defer.returnValue(manager())
 
     @defer.inlineCallbacks
     def get_max_token(self, store):
 |