diff options
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r-- | synapse/storage/databases/main/cache.py | 4 | ||||
-rw-r--r-- | synapse/storage/databases/main/events.py | 7 | ||||
-rw-r--r-- | synapse/storage/databases/main/roommember.py | 37 |
3 files changed, 44 insertions, 4 deletions
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 2d7511d613..dd4e83a2ad 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -192,6 +192,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self.get_unread_event_push_actions_by_room_for_user.invalidate((room_id,)) + # The `_get_membership_from_event_id` is immutable, except for the + # case where we look up an event *before* persisting it. + self._get_membership_from_event_id.invalidate((event_id,)) + if not backfilled: self._events_stream_cache.entity_has_changed(room_id, stream_ordering) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 1f60aef180..d253243125 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1745,6 +1745,13 @@ class PersistEventsStore: (event.state_key,), ) + # The `_get_membership_from_event_id` is immutable, except for the + # case where we look up an event *before* persisting it. + txn.call_after( + self.store._get_membership_from_event_id.invalidate, + (event.event_id,), + ) + # We update the local_current_membership table only if the event is # "current", i.e., its something that has just happened. # diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index bef675b845..3248da5356 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -63,6 +63,14 @@ _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update" _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership" +@attr.s(frozen=True, slots=True, auto_attribs=True) +class EventIdMembership: + """Returned by `get_membership_from_event_ids`""" + + user_id: str + membership: str + + class RoomMemberWorkerStore(EventsWorkerStore): def __init__( self, @@ -772,7 +780,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): retcols=("user_id", "display_name", "avatar_url", "event_id"), keyvalues={"membership": Membership.JOIN}, batch_size=500, - desc="_get_membership_from_event_ids", + desc="_get_joined_profiles_from_event_ids", ) return { @@ -1000,12 +1008,26 @@ class RoomMemberWorkerStore(EventsWorkerStore): return set(room_ids) + @cached(max_entries=5000) + async def _get_membership_from_event_id( + self, member_event_id: str + ) -> Optional[EventIdMembership]: + raise NotImplementedError() + + @cachedList( + cached_method_name="_get_membership_from_event_id", list_name="member_event_ids" + ) async def get_membership_from_event_ids( self, member_event_ids: Iterable[str] - ) -> List[dict]: - """Get user_id and membership of a set of event IDs.""" + ) -> Dict[str, Optional[EventIdMembership]]: + """Get user_id and membership of a set of event IDs. + + Returns: + Mapping from event ID to `EventIdMembership` if the event is a + membership event, otherwise the value is None. + """ - return await self.db_pool.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="room_memberships", column="event_id", iterable=member_event_ids, @@ -1015,6 +1037,13 @@ class RoomMemberWorkerStore(EventsWorkerStore): desc="get_membership_from_event_ids", ) + return { + row["event_id"]: EventIdMembership( + membership=row["membership"], user_id=row["user_id"] + ) + for row in rows + } + async def is_local_host_in_room_ignoring_users( self, room_id: str, ignore_users: Collection[str] ) -> bool: |