diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 3c13859faa..2dfe4c0b66 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -460,8 +460,17 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
# Cast safety: this corresponds to the types returned by the query above.
rows.extend(cast(Iterable[Tuple[str, int]], cur))
- # Sort so that we handle rows in order for each instance.
- rows.sort()
+ # Sort by stream_id (ascending, lowest -> highest) so that we handle
+ # rows in order for each instance because we don't want to overwrite
+ # the current_position of an instance to a lower stream ID than
+ # we're actually at.
+ def sort_by_stream_id_key_func(row: Tuple[str, int]) -> int:
+ (instance, stream_id) = row
+ # If `stream_id` is ever `None`, we will see a `TypeError: '<'
+ # not supported between instances of 'NoneType' and 'X'` error.
+ return stream_id
+
+ rows.sort(key=sort_by_stream_id_key_func)
with self._lock:
for (
diff --git a/synapse/storage/util/partial_state_events_tracker.py b/synapse/storage/util/partial_state_events_tracker.py
index 211437cfaa..b4bf49dace 100644
--- a/synapse/storage/util/partial_state_events_tracker.py
+++ b/synapse/storage/util/partial_state_events_tracker.py
@@ -20,6 +20,7 @@ from twisted.internet import defer
from twisted.internet.defer import Deferred
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.logging.opentracing import trace_with_opname
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.room import RoomWorkerStore
from synapse.util import unwrapFirstError
@@ -58,6 +59,7 @@ class PartialStateEventsTracker:
for o in observers:
o.callback(None)
+ @trace_with_opname("PartialStateEventsTracker.await_full_state")
async def await_full_state(self, event_ids: Collection[str]) -> None:
"""Wait for all the given events to have full state.
@@ -151,6 +153,7 @@ class PartialCurrentStateTracker:
for o in observers:
o.callback(None)
+ @trace_with_opname("PartialCurrentStateTracker.await_full_state")
async def await_full_state(self, room_id: str) -> None:
# We add the deferred immediately so that the DB call to check for
# partial state doesn't race when we unpartial the room.
@@ -166,6 +169,7 @@ class PartialCurrentStateTracker:
logger.info(
"Awaiting un-partial-stating of room %s",
room_id,
+ stack_info=True,
)
await make_deferred_yieldable(d)
|