summary refs log tree commit diff
path: root/synapse/storage/util
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/util')
-rw-r--r--synapse/storage/util/partial_state_events_tracker.py60
1 files changed, 60 insertions, 0 deletions
diff --git a/synapse/storage/util/partial_state_events_tracker.py b/synapse/storage/util/partial_state_events_tracker.py
index a61a951ef0..211437cfaa 100644
--- a/synapse/storage/util/partial_state_events_tracker.py
+++ b/synapse/storage/util/partial_state_events_tracker.py
@@ -21,6 +21,7 @@ from twisted.internet.defer import Deferred
 
 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.storage.databases.main.room import RoomWorkerStore
 from synapse.util import unwrapFirstError
 
 logger = logging.getLogger(__name__)
@@ -118,3 +119,62 @@ class PartialStateEventsTracker:
                     observer_set.discard(observer)
                     if not observer_set:
                         del self._observers[event_id]
+
+
+class PartialCurrentStateTracker:
+    """Keeps track of which rooms have partial state, after partial-state joins"""
+
+    def __init__(self, store: RoomWorkerStore):
+        self._store = store
+
+        # a map from room id to a set of Deferreds which are waiting for that room to be
+        # un-partial-stated.
+        self._observers: Dict[str, Set[Deferred[None]]] = defaultdict(set)
+
+    def notify_un_partial_stated(self, room_id: str) -> None:
+        """Notify that we now have full current state for a given room
+
+        Unblocks any callers to await_full_state() for that room.
+
+        Args:
+            room_id: the room that now has full current state.
+        """
+        observers = self._observers.pop(room_id, None)
+        if not observers:
+            return
+        logger.info(
+            "Notifying %i things waiting for un-partial-stating of room %s",
+            len(observers),
+            room_id,
+        )
+        with PreserveLoggingContext():
+            for o in observers:
+                o.callback(None)
+
+    async def await_full_state(self, room_id: str) -> None:
+        # We add the deferred immediately so that the DB call to check for
+        # partial state doesn't race when we unpartial the room.
+        d: Deferred[None] = Deferred()
+        self._observers.setdefault(room_id, set()).add(d)
+
+        try:
+            # Check if the room has partial current state or not.
+            has_partial_state = await self._store.is_partial_state_room(room_id)
+            if not has_partial_state:
+                return
+
+            logger.info(
+                "Awaiting un-partial-stating of room %s",
+                room_id,
+            )
+
+            await make_deferred_yieldable(d)
+
+            logger.info("Room has un-partial-stated")
+        finally:
+            # Remove the added observer, and remove the room entry if its empty.
+            ds = self._observers.get(room_id)
+            if ds is not None:
+                ds.discard(d)
+                if not ds:
+                    self._observers.pop(room_id, None)