diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 2d7511d613..dd4e83a2ad 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -192,6 +192,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self.get_unread_event_push_actions_by_room_for_user.invalidate((room_id,))
+ # The `_get_membership_from_event_id` is immutable, except for the
+ # case where we look up an event *before* persisting it.
+ self._get_membership_from_event_id.invalidate((event_id,))
+
if not backfilled:
self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 1f60aef180..d253243125 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1745,6 +1745,13 @@ class PersistEventsStore:
(event.state_key,),
)
+ # The `_get_membership_from_event_id` is immutable, except for the
+ # case where we look up an event *before* persisting it.
+ txn.call_after(
+ self.store._get_membership_from_event_id.invalidate,
+ (event.event_id,),
+ )
+
# We update the local_current_membership table only if the event is
# "current", i.e., its something that has just happened.
#
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index bef675b845..3248da5356 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -63,6 +63,14 @@ _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class EventIdMembership:
+ """Returned by `get_membership_from_event_ids`"""
+
+ user_id: str
+ membership: str
+
+
class RoomMemberWorkerStore(EventsWorkerStore):
def __init__(
self,
@@ -772,7 +780,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
retcols=("user_id", "display_name", "avatar_url", "event_id"),
keyvalues={"membership": Membership.JOIN},
batch_size=500,
- desc="_get_membership_from_event_ids",
+ desc="_get_joined_profiles_from_event_ids",
)
return {
@@ -1000,12 +1008,26 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return set(room_ids)
+ @cached(max_entries=5000)
+ async def _get_membership_from_event_id(
+ self, member_event_id: str
+ ) -> Optional[EventIdMembership]:
+ raise NotImplementedError()
+
+ @cachedList(
+ cached_method_name="_get_membership_from_event_id", list_name="member_event_ids"
+ )
async def get_membership_from_event_ids(
self, member_event_ids: Iterable[str]
- ) -> List[dict]:
- """Get user_id and membership of a set of event IDs."""
+ ) -> Dict[str, Optional[EventIdMembership]]:
+ """Get user_id and membership of a set of event IDs.
+
+ Returns:
+ Mapping from event ID to `EventIdMembership` if the event is a
+ membership event, otherwise the value is None.
+ """
- return await self.db_pool.simple_select_many_batch(
+ rows = await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=member_event_ids,
@@ -1015,6 +1037,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
desc="get_membership_from_event_ids",
)
+ return {
+ row["event_id"]: EventIdMembership(
+ membership=row["membership"], user_id=row["user_id"]
+ )
+ for row in rows
+ }
+
async def is_local_host_in_room_ignoring_users(
self, room_id: str, ignore_users: Collection[str]
) -> bool:
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 7d543fdbe0..b402922817 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -1023,8 +1023,13 @@ class EventsPersistenceStorage:
# Check if any of the changes that we don't have events for are joins.
if events_to_check:
- rows = await self.main_store.get_membership_from_event_ids(events_to_check)
- is_still_joined = any(row["membership"] == Membership.JOIN for row in rows)
+ members = await self.main_store.get_membership_from_event_ids(
+ events_to_check
+ )
+ is_still_joined = any(
+ member and member.membership == Membership.JOIN
+ for member in members.values()
+ )
if is_still_joined:
return True
@@ -1060,9 +1065,11 @@ class EventsPersistenceStorage:
), event_id in current_state.items()
if typ == EventTypes.Member and not self.is_mine_id(state_key)
]
- rows = await self.main_store.get_membership_from_event_ids(remote_event_ids)
+ members = await self.main_store.get_membership_from_event_ids(remote_event_ids)
potentially_left_users.update(
- row["user_id"] for row in rows if row["membership"] == Membership.JOIN
+ member.user_id
+ for member in members.values()
+ if member and member.membership == Membership.JOIN
)
return False
|