diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index a2f8310388..e30f9c76d4 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -80,6 +80,10 @@ class SQLBaseStore(metaclass=ABCMeta):
)
self._attempt_to_invalidate_cache("get_local_users_in_room", (room_id,))
+ # There's no easy way of invalidating this cache for just the users
+ # that have changed, so we just clear the entire thing.
+ self._attempt_to_invalidate_cache("does_pair_of_users_share_a_room", None)
+
for user_id in members_changed:
self._attempt_to_invalidate_cache(
"get_user_in_room_with_profile", (room_id, user_id)
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index e08f956e6e..1e35046e07 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -82,13 +82,15 @@ class StateStorageController:
return state_group_delta.prev_group, state_group_delta.delta_ids
async def get_state_groups_ids(
- self, _room_id: str, event_ids: Collection[str]
+ self, _room_id: str, event_ids: Collection[str], await_full_state: bool = True
) -> Dict[int, MutableStateMap[str]]:
"""Get the event IDs of all the state for the state groups for the given events
Args:
_room_id: id of the room for these events
event_ids: ids of the events
+ await_full_state: if `True`, will block if we do not yet have complete
+ state at these events.
Returns:
dict of state_group_id -> (dict of (type, state_key) -> event id)
@@ -100,7 +102,9 @@ class StateStorageController:
if not event_ids:
return {}
- event_to_groups = await self.get_state_group_for_events(event_ids)
+ event_to_groups = await self.get_state_group_for_events(
+ event_ids, await_full_state=await_full_state
+ )
groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(groups)
@@ -334,6 +338,10 @@ class StateStorageController:
event_ids: events to get state groups for
await_full_state: if true, will block if we do not yet have complete
state at these events.
+
+ Raises:
+ RuntimeError if we don't have a state group for one or more of the events
+ (ie. they are outliers or unknown)
"""
if await_full_state:
await self._partial_state_events_tracker.await_full_state(event_ids)
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 5914a35420..29c99c6357 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -2110,11 +2110,29 @@ class EventsWorkerStore(SQLBaseStore):
def _get_partial_state_events_batch_txn(
txn: LoggingTransaction, room_id: str
) -> List[str]:
+ # we want to work through the events from oldest to newest, so
+ # we only want events whose prev_events do *not* have partial state - hence
+ # the 'NOT EXISTS' clause in the below.
+ #
+ # This is necessary because ordering by stream ordering isn't quite enough
+ # to ensure that we work from oldest to newest event (in particular,
+ # if an event is initially persisted as an outlier and later de-outliered,
+ # it can end up with a lower stream_ordering than its prev_events).
+ #
+ # Typically this means we'll only return one event per batch, but that's
+ # hard to do much about.
+ #
+ # See also: https://github.com/matrix-org/synapse/issues/13001
txn.execute(
"""
SELECT event_id FROM partial_state_events AS pse
JOIN events USING (event_id)
- WHERE pse.room_id = ?
+ WHERE pse.room_id = ? AND
+ NOT EXISTS(
+ SELECT 1 FROM event_edges AS ee
+ JOIN partial_state_events AS prev_pse ON (prev_pse.event_id=ee.prev_event_id)
+ WHERE ee.event_id=pse.event_id
+ )
ORDER BY events.stream_ordering
LIMIT 100
""",
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index b457bc189e..7bd27790eb 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -62,7 +62,6 @@ class RelationsWorkerStore(SQLBaseStore):
room_id: str,
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
- aggregation_key: Optional[str] = None,
limit: int = 5,
direction: str = "b",
from_token: Optional[StreamToken] = None,
@@ -76,7 +75,6 @@ class RelationsWorkerStore(SQLBaseStore):
room_id: The room the event belongs to.
relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given.
- aggregation_key: Only fetch events with this aggregation key, if given.
limit: Only fetch the most recent `limit` events.
direction: Whether to fetch the most recent first (`"b"`) or the
oldest first (`"f"`).
@@ -105,10 +103,6 @@ class RelationsWorkerStore(SQLBaseStore):
where_clause.append("type = ?")
where_args.append(event_type)
- if aggregation_key:
- where_clause.append("aggregation_key = ?")
- where_args.append(aggregation_key)
-
pagination_clause = generate_pagination_where_clause(
direction=direction,
column_names=("topological_ordering", "stream_ordering"),
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index d6d485507b..0f1f0d11ea 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -207,7 +207,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
def _construct_room_type_where_clause(
self, room_types: Union[List[Union[str, None]], None]
) -> Tuple[Union[str, None], List[str]]:
- if not room_types or not self.config.experimental.msc3827_enabled:
+ if not room_types:
return None, []
else:
# We use None when we want get rooms without a type
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index df6b82660e..e2cccc688c 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -21,6 +21,7 @@ from typing import (
FrozenSet,
Iterable,
List,
+ Mapping,
Optional,
Set,
Tuple,
@@ -55,6 +56,7 @@ from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
+from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
if TYPE_CHECKING:
@@ -183,7 +185,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
self._check_safe_current_state_events_membership_updated_txn,
)
- @cached(max_entries=100000, iterable=True, prune_unread_entries=False)
+ @cached(max_entries=100000, iterable=True)
async def get_users_in_room(self, room_id: str) -> List[str]:
return await self.db_pool.runInteraction(
"get_users_in_room", self.get_users_in_room_txn, room_id
@@ -561,7 +563,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return results_dict.get("membership"), results_dict.get("event_id")
- @cached(max_entries=500000, iterable=True, prune_unread_entries=False)
+ @cached(max_entries=500000, iterable=True)
async def get_rooms_for_user_with_stream_ordering(
self, user_id: str
) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
@@ -732,25 +734,76 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
return frozenset(r.room_id for r in rooms)
- @cached(
- max_entries=500000,
- cache_context=True,
- iterable=True,
- prune_unread_entries=False,
+ @cached(max_entries=10000)
+ async def does_pair_of_users_share_a_room(
+ self, user_id: str, other_user_id: str
+ ) -> bool:
+ raise NotImplementedError()
+
+ @cachedList(
+ cached_method_name="does_pair_of_users_share_a_room", list_name="other_user_ids"
)
- async def get_users_who_share_room_with_user(
- self, user_id: str, cache_context: _CacheContext
+ async def _do_users_share_a_room(
+ self, user_id: str, other_user_ids: Collection[str]
+ ) -> Mapping[str, Optional[bool]]:
+ """Return mapping from user ID to whether they share a room with the
+ given user.
+
+ Note: `None` and `False` are equivalent and mean they don't share a
+ room.
+ """
+
+ def do_users_share_a_room_txn(
+ txn: LoggingTransaction, user_ids: Collection[str]
+ ) -> Dict[str, bool]:
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "state_key", user_ids
+ )
+
+ # This query works by fetching both the list of rooms for the target
+ # user and the set of other users, and then checking if there is any
+ # overlap.
+ sql = f"""
+ SELECT b.state_key
+ FROM (
+ SELECT room_id FROM current_state_events
+ WHERE type = 'm.room.member' AND membership = 'join' AND state_key = ?
+ ) AS a
+ INNER JOIN (
+ SELECT room_id, state_key FROM current_state_events
+ WHERE type = 'm.room.member' AND membership = 'join' AND {clause}
+ ) AS b using (room_id)
+ LIMIT 1
+ """
+
+ txn.execute(sql, (user_id, *args))
+ return {u: True for u, in txn}
+
+ to_return = {}
+ for batch_user_ids in batch_iter(other_user_ids, 1000):
+ res = await self.db_pool.runInteraction(
+ "do_users_share_a_room", do_users_share_a_room_txn, batch_user_ids
+ )
+ to_return.update(res)
+
+ return to_return
+
+ async def do_users_share_a_room(
+ self, user_id: str, other_user_ids: Collection[str]
) -> Set[str]:
+ """Return the set of users who share a room with the first users"""
+
+ user_dict = await self._do_users_share_a_room(user_id, other_user_ids)
+
+ return {u for u, share_room in user_dict.items() if share_room}
+
+ async def get_users_who_share_room_with_user(self, user_id: str) -> Set[str]:
"""Returns the set of users who share a room with `user_id`"""
- room_ids = await self.get_rooms_for_user(
- user_id, on_invalidate=cache_context.invalidate
- )
+ room_ids = await self.get_rooms_for_user(user_id)
user_who_share_room = set()
for room_id in room_ids:
- user_ids = await self.get_users_in_room(
- room_id, on_invalidate=cache_context.invalidate
- )
+ user_ids = await self.get_users_in_room(room_id)
user_who_share_room.update(user_ids)
return user_who_share_room
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 9674c4a757..f70705a0af 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -419,13 +419,15 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# anything that was rejected should have the same state as its
# predecessor.
if context.rejected:
- assert context.state_group == context.state_group_before_event
+ state_group = context.state_group_before_event
+ else:
+ state_group = context.state_group
self.db_pool.simple_update_txn(
txn,
table="event_to_state_groups",
keyvalues={"event_id": event.event_id},
- updatevalues={"state_group": context.state_group},
+ updatevalues={"state_group": state_group},
)
self.db_pool.simple_delete_one_txn(
@@ -440,7 +442,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
txn.call_after(
self._get_state_group_for_event.prefill,
(event.event_id,),
- context.state_group,
+ state_group,
)
|