diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index aea96e9d24..84f844b79e 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -292,6 +292,7 @@ class RelationsWorkerStore(SQLBaseStore):
to_device_key=0,
device_list_key=0,
groups_key=0,
+ un_partial_stated_rooms_key=0,
)
return events[:limit], next_token
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 6a65b2a89b..3aa7b94560 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -26,6 +26,7 @@ from typing import (
Mapping,
Optional,
Sequence,
+ Set,
Tuple,
Union,
cast,
@@ -1294,10 +1295,44 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
instance_name
)
+ async def get_un_partial_stated_rooms_between(
+ self, last_id: int, current_id: int, room_ids: Collection[str]
+ ) -> Set[str]:
+ """Get all rooms that got un partial stated between `last_id` exclusive and
+ `current_id` inclusive.
+
+ Returns:
+ The list of room ids.
+ """
+
+ if last_id == current_id:
+ return set()
+
+ def _get_un_partial_stated_rooms_between_txn(
+ txn: LoggingTransaction,
+ ) -> Set[str]:
+ sql = """
+ SELECT DISTINCT room_id FROM un_partial_stated_room_stream
+ WHERE ? < stream_id AND stream_id <= ? AND
+ """
+
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "room_id", room_ids
+ )
+
+ txn.execute(sql + clause, [last_id, current_id] + args)
+
+ return {r[0] for r in txn}
+
+ return await self.db_pool.runInteraction(
+ "get_un_partial_stated_rooms_between",
+ _get_un_partial_stated_rooms_between_txn,
+ )
+
async def get_un_partial_stated_rooms_from_stream(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]:
- """Get updates for caches replication stream.
+ """Get updates for un partial stated rooms replication stream.
Args:
instance_name: The writer we want to fetch updates from. Unused
@@ -2304,16 +2339,16 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
(room_id,),
)
- async def clear_partial_state_room(self, room_id: str) -> bool:
+ async def clear_partial_state_room(self, room_id: str) -> Optional[int]:
"""Clears the partial state flag for a room.
Args:
room_id: The room whose partial state flag is to be cleared.
Returns:
- `True` if the partial state flag has been cleared successfully.
+ The corresponding stream id for the un-partial-stated rooms stream.
- `False` if the partial state flag could not be cleared because the room
+ `None` if the partial state flag could not be cleared because the room
still contains events with partial state.
"""
try:
@@ -2324,7 +2359,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
room_id,
un_partial_state_room_stream_id,
)
- return True
+ return un_partial_state_room_stream_id
except self.db_pool.engine.module.IntegrityError as e:
# Assume that any `IntegrityError`s are due to partial state events.
logger.info(
@@ -2332,7 +2367,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
room_id,
e,
)
- return False
+ return None
def _clear_partial_state_room_txn(
self,
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index f02c1d7ea7..8e2ba7b7b4 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -15,6 +15,7 @@
import logging
from typing import (
TYPE_CHECKING,
+ AbstractSet,
Collection,
Dict,
FrozenSet,
@@ -47,7 +48,13 @@ from synapse.storage.roommember import (
ProfileInfo,
RoomsForUser,
)
-from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain_from_id
+from synapse.types import (
+ JsonDict,
+ PersistedEventPosition,
+ StateMap,
+ StrCollection,
+ get_domain_from_id,
+)
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
@@ -385,7 +392,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
self,
user_id: str,
membership_list: Collection[str],
- excluded_rooms: Optional[List[str]] = None,
+ excluded_rooms: StrCollection = (),
) -> List[RoomsForUser]:
"""Get all the rooms for this *local* user where the membership for this user
matches one in the membership list.
@@ -412,10 +419,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
# Now we filter out forgotten and excluded rooms
- rooms_to_exclude: Set[str] = await self.get_forgotten_rooms_for_user(user_id)
+ rooms_to_exclude = await self.get_forgotten_rooms_for_user(user_id)
if excluded_rooms is not None:
- rooms_to_exclude.update(set(excluded_rooms))
+ # Take a copy to avoid mutating the in-cache set
+ rooms_to_exclude = set(rooms_to_exclude)
+ rooms_to_exclude.update(excluded_rooms)
return [room for room in rooms if room.room_id not in rooms_to_exclude]
@@ -1169,7 +1178,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return count == 0
@cached()
- async def get_forgotten_rooms_for_user(self, user_id: str) -> Set[str]:
+ async def get_forgotten_rooms_for_user(self, user_id: str) -> AbstractSet[str]:
"""Gets all rooms the user has forgotten.
Args:
|