summary refs log tree commit diff
path: root/synapse/storage/databases/main/roommember.py
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2020-08-12 12:14:34 -0400
committerGitHub <noreply@github.com>2020-08-12 12:14:34 -0400
commitfbe930dad28c81a5e563ddc8683f65b8279aad52 (patch)
tree2a8a98977cf85509212373f7222f12cfc6226ad7 /synapse/storage/databases/main/roommember.py
parentConvert devices database to async/await. (#8069) (diff)
downloadsynapse-fbe930dad28c81a5e563ddc8683f65b8279aad52.tar.xz
Convert the roommember database to async/await. (#8070)
Diffstat (limited to 'synapse/storage/databases/main/roommember.py')
-rw-r--r--synapse/storage/databases/main/roommember.py263
1 files changed, 98 insertions, 165 deletions
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 7c5be251bd..b2fcfc9bfe 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -15,11 +15,13 @@
 # limitations under the License.
 
 import logging
-from typing import Iterable, List, Set
+from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set
 
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, Membership
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
 from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import (
@@ -40,9 +42,12 @@ from synapse.storage.roommember import (
 from synapse.types import Collection, get_domain_from_id
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import intern_string
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
+from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
 from synapse.util.metrics import Measure
 
+if TYPE_CHECKING:
+    from synapse.state import _StateCacheEntry
+
 logger = logging.getLogger(__name__)
 
 
@@ -150,12 +155,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             )
 
     @cached(max_entries=100000, iterable=True)
-    def get_users_in_room(self, room_id):
+    def get_users_in_room(self, room_id: str):
         return self.db_pool.runInteraction(
             "get_users_in_room", self.get_users_in_room_txn, room_id
         )
 
-    def get_users_in_room_txn(self, txn, room_id):
+    def get_users_in_room_txn(self, txn, room_id: str) -> List[str]:
         # If we can assume current_state_events.membership is up to date
         # then we can avoid a join, which is a Very Good Thing given how
         # frequently this function gets called.
@@ -178,11 +183,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return [r[0] for r in txn]
 
     @cached(max_entries=100000)
-    def get_room_summary(self, room_id):
+    def get_room_summary(self, room_id: str):
         """ Get the details of a room roughly suitable for use by the room
         summary extension to /sync. Useful when lazy loading room members.
         Args:
-            room_id (str): The room ID to query
+            room_id: The room ID to query
         Returns:
             Deferred[dict[str, MemberSummary]:
                 dict of membership states, pointing to a MemberSummary named tuple.
@@ -261,78 +266,59 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return self.db_pool.runInteraction("get_room_summary", _get_room_summary_txn)
 
-    def _get_user_counts_in_room_txn(self, txn, room_id):
-        """
-        Get the user count in a room by membership.
-
-        Args:
-            room_id (str)
-            membership (Membership)
-
-        Returns:
-            Deferred[int]
-        """
-        sql = """
-        SELECT m.membership, count(*) FROM room_memberships as m
-            INNER JOIN current_state_events as c USING(event_id)
-            WHERE c.type = 'm.room.member' AND c.room_id = ?
-            GROUP BY m.membership
-        """
-
-        txn.execute(sql, (room_id,))
-        return {row[0]: row[1] for row in txn}
-
     @cached()
-    def get_invited_rooms_for_local_user(self, user_id):
-        """ Get all the rooms the *local* user is invited to
+    def get_invited_rooms_for_local_user(self, user_id: str) -> Awaitable[RoomsForUser]:
+        """Get all the rooms the *local* user is invited to.
 
         Args:
-            user_id (str): The user ID.
+            user_id: The user ID.
+
         Returns:
-            A deferred list of RoomsForUser.
+            A awaitable list of RoomsForUser.
         """
 
         return self.get_rooms_for_local_user_where_membership_is(
             user_id, [Membership.INVITE]
         )
 
-    @defer.inlineCallbacks
-    def get_invite_for_local_user_in_room(self, user_id, room_id):
-        """Gets the invite for the given *local* user and room
+    async def get_invite_for_local_user_in_room(
+        self, user_id: str, room_id: str
+    ) -> Optional[RoomsForUser]:
+        """Gets the invite for the given *local* user and room.
 
         Args:
-            user_id (str)
-            room_id (str)
+            user_id: The user ID to find the invite of.
+            room_id: The room to user was invited to.
 
         Returns:
-            Deferred: Resolves to either a RoomsForUser or None if no invite was
-                found.
+            Either a RoomsForUser or None if no invite was found.
         """
-        invites = yield self.get_invited_rooms_for_local_user(user_id)
+        invites = await self.get_invited_rooms_for_local_user(user_id)
         for invite in invites:
             if invite.room_id == room_id:
                 return invite
         return None
 
-    @defer.inlineCallbacks
-    def get_rooms_for_local_user_where_membership_is(self, user_id, membership_list):
-        """ Get all the rooms for this *local* user where the membership for this user
+    async def get_rooms_for_local_user_where_membership_is(
+        self, user_id: str, membership_list: List[str]
+    ) -> Optional[List[RoomsForUser]]:
+        """Get all the rooms for this *local* user where the membership for this user
         matches one in the membership list.
 
         Filters out forgotten rooms.
 
         Args:
-            user_id (str): The user ID.
-            membership_list (list): A list of synapse.api.constants.Membership
-            values which the user must be in.
+            user_id: The user ID.
+            membership_list: A list of synapse.api.constants.Membership
+                values which the user must be in.
 
         Returns:
-            Deferred[list[RoomsForUser]]
+            The RoomsForUser that the user matches the membership types.
         """
         if not membership_list:
-            return defer.succeed(None)
+            return None
 
-        rooms = yield self.db_pool.runInteraction(
+        rooms = await self.db_pool.runInteraction(
             "get_rooms_for_local_user_where_membership_is",
             self._get_rooms_for_local_user_where_membership_is_txn,
             user_id,
@@ -340,12 +326,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         )
 
         # Now we filter out forgotten rooms
-        forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id)
+        forgotten_rooms = await self.get_forgotten_rooms_for_user(user_id)
         return [room for room in rooms if room.room_id not in forgotten_rooms]
 
     def _get_rooms_for_local_user_where_membership_is_txn(
-        self, txn, user_id, membership_list
-    ):
+        self, txn, user_id: str, membership_list: List[str]
+    ) -> List[RoomsForUser]:
         # Paranoia check.
         if not self.hs.is_mine_id(user_id):
             raise Exception(
@@ -374,14 +360,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return results
 
     @cached(max_entries=500000, iterable=True)
-    def get_rooms_for_user_with_stream_ordering(self, user_id):
+    def get_rooms_for_user_with_stream_ordering(self, user_id: str):
         """Returns a set of room_ids the user is currently joined to.
 
         If a remote user only returns rooms this server is currently
         participating in.
 
         Args:
-            user_id (str)
+            user_id
 
         Returns:
             Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
@@ -394,7 +380,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             user_id,
         )
 
-    def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id):
+    def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id: str):
         # We use `current_state_events` here and not `local_current_membership`
         # as a) this gets called with remote users and b) this only gets called
         # for rooms the server is participating in.
@@ -458,37 +444,39 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             _get_users_server_still_shares_room_with_txn,
         )
 
-    @defer.inlineCallbacks
-    def get_rooms_for_user(self, user_id, on_invalidate=None):
+    async def get_rooms_for_user(self, user_id: str, on_invalidate=None):
         """Returns a set of room_ids the user is currently joined to.
 
         If a remote user only returns rooms this server is currently
         participating in.
         """
-        rooms = yield self.get_rooms_for_user_with_stream_ordering(
+        rooms = await self.get_rooms_for_user_with_stream_ordering(
             user_id, on_invalidate=on_invalidate
         )
         return frozenset(r.room_id for r in rooms)
 
-    @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
-    def get_users_who_share_room_with_user(self, user_id, cache_context):
+    @cached(max_entries=500000, cache_context=True, iterable=True)
+    async def get_users_who_share_room_with_user(
+        self, user_id: str, cache_context: _CacheContext
+    ) -> Set[str]:
         """Returns the set of users who share a room with `user_id`
         """
-        room_ids = yield self.get_rooms_for_user(
+        room_ids = await self.get_rooms_for_user(
             user_id, on_invalidate=cache_context.invalidate
         )
 
         user_who_share_room = set()
         for room_id in room_ids:
-            user_ids = yield self.get_users_in_room(
+            user_ids = await self.get_users_in_room(
                 room_id, on_invalidate=cache_context.invalidate
             )
             user_who_share_room.update(user_ids)
 
         return user_who_share_room
 
-    @defer.inlineCallbacks
-    def get_joined_users_from_context(self, event, context):
+    async def get_joined_users_from_context(
+        self, event: EventBase, context: EventContext
+    ):
         state_group = context.state_group
         if not state_group:
             # If state_group is None it means it has yet to be assigned a
@@ -497,14 +485,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             # To do this we set the state_group to a new object as object() != object()
             state_group = object()
 
-        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
-        result = yield self._get_joined_users_from_context(
+        current_state_ids = await context.get_current_state_ids()
+        return await self._get_joined_users_from_context(
             event.room_id, state_group, current_state_ids, event=event, context=context
         )
-        return result
 
-    @defer.inlineCallbacks
-    def get_joined_users_from_state(self, room_id, state_entry):
+    async def get_joined_users_from_state(self, room_id, state_entry):
         state_group = state_entry.state_group
         if not state_group:
             # If state_group is None it means it has yet to be assigned a
@@ -514,16 +500,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             state_group = object()
 
         with Measure(self._clock, "get_joined_users_from_state"):
-            return (
-                yield self._get_joined_users_from_context(
-                    room_id, state_group, state_entry.state, context=state_entry
-                )
+            return await self._get_joined_users_from_context(
+                room_id, state_group, state_entry.state, context=state_entry
             )
 
-    @cachedInlineCallbacks(
-        num_args=2, cache_context=True, iterable=True, max_entries=100000
-    )
-    def _get_joined_users_from_context(
+    @cached(num_args=2, cache_context=True, iterable=True, max_entries=100000)
+    async def _get_joined_users_from_context(
         self,
         room_id,
         state_group,
@@ -535,7 +517,6 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         # We don't use `state_group`, it's there so that we can cache based
         # on it. However, it's important that it's never None, since two current_states
         # with a state_group of None are likely to be different.
-        # See bulk_get_push_rules_for_room for how we work around this.
         assert state_group is not None
 
         users_in_room = {}
@@ -588,7 +569,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 missing_member_event_ids.append(event_id)
 
         if missing_member_event_ids:
-            event_to_memberships = yield self._get_joined_profiles_from_event_ids(
+            event_to_memberships = await self._get_joined_profiles_from_event_ids(
                 missing_member_event_ids
             )
             users_in_room.update((row for row in event_to_memberships.values() if row))
@@ -612,12 +593,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         list_name="event_ids",
         inlineCallbacks=True,
     )
-    def _get_joined_profiles_from_event_ids(self, event_ids):
+    def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
         """For given set of member event_ids check if they point to a join
         event and if so return the associated user and profile info.
 
         Args:
-            event_ids (Iterable[str]): The member event IDs to lookup
+            event_ids: The member event IDs to lookup
 
         Returns:
             Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID
@@ -644,8 +625,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             for row in rows
         }
 
-    @cachedInlineCallbacks(max_entries=10000)
-    def is_host_joined(self, room_id, host):
+    @cached(max_entries=10000)
+    async def is_host_joined(self, room_id: str, host: str) -> bool:
         if "%" in host or "_" in host:
             raise Exception("Invalid host name")
 
@@ -664,7 +645,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         # the returned user actually has the correct domain.
         like_clause = "%:" + host
 
-        rows = yield self.db_pool.execute(
+        rows = await self.db_pool.execute(
             "is_host_joined", None, sql, room_id, like_clause
         )
 
@@ -678,50 +659,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return True
 
-    @cachedInlineCallbacks()
-    def was_host_joined(self, room_id, host):
-        """Check whether the server is or ever was in the room.
-
-        Args:
-            room_id (str)
-            host (str)
-
-        Returns:
-            Deferred: Resolves to True if the host is/was in the room, otherwise
-            False.
-        """
-        if "%" in host or "_" in host:
-            raise Exception("Invalid host name")
-
-        sql = """
-            SELECT user_id FROM room_memberships
-            WHERE room_id = ?
-                AND user_id LIKE ?
-                AND membership = 'join'
-            LIMIT 1
-        """
-
-        # We do need to be careful to ensure that host doesn't have any wild cards
-        # in it, but we checked above for known ones and we'll check below that
-        # the returned user actually has the correct domain.
-        like_clause = "%:" + host
-
-        rows = yield self.db_pool.execute(
-            "was_host_joined", None, sql, room_id, like_clause
-        )
-
-        if not rows:
-            return False
-
-        user_id = rows[0][0]
-        if get_domain_from_id(user_id) != host:
-            # This can only happen if the host name has something funky in it
-            raise Exception("Invalid host name")
-
-        return True
-
-    @defer.inlineCallbacks
-    def get_joined_hosts(self, room_id, state_entry):
+    async def get_joined_hosts(self, room_id: str, state_entry):
         state_group = state_entry.state_group
         if not state_group:
             # If state_group is None it means it has yet to be assigned a
@@ -731,32 +669,28 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             state_group = object()
 
         with Measure(self._clock, "get_joined_hosts"):
-            return (
-                yield self._get_joined_hosts(
-                    room_id, state_group, state_entry.state, state_entry=state_entry
-                )
+            return await self._get_joined_hosts(
+                room_id, state_group, state_entry.state, state_entry=state_entry
             )
 
-    @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True)
-    # @defer.inlineCallbacks
-    def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry):
+    @cached(num_args=2, max_entries=10000, iterable=True)
+    async def _get_joined_hosts(
+        self, room_id, state_group, current_state_ids, state_entry
+    ):
         # We don't use `state_group`, its there so that we can cache based
         # on it. However, its important that its never None, since two current_state's
         # with a state_group of None are likely to be different.
-        # See bulk_get_push_rules_for_room for how we work around this.
         assert state_group is not None
 
-        cache = yield self._get_joined_hosts_cache(room_id)
-        joined_hosts = yield cache.get_destinations(state_entry)
-
-        return joined_hosts
+        cache = await self._get_joined_hosts_cache(room_id)
+        return await cache.get_destinations(state_entry)
 
     @cached(max_entries=10000)
-    def _get_joined_hosts_cache(self, room_id):
+    def _get_joined_hosts_cache(self, room_id: str) -> "_JoinedHostsCache":
         return _JoinedHostsCache(self, room_id)
 
-    @cachedInlineCallbacks(num_args=2)
-    def did_forget(self, user_id, room_id):
+    @cached(num_args=2)
+    async def did_forget(self, user_id: str, room_id: str) -> bool:
         """Returns whether user_id has elected to discard history for room_id.
 
         Returns False if they have since re-joined."""
@@ -778,15 +712,15 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             rows = txn.fetchall()
             return rows[0][0]
 
-        count = yield self.db_pool.runInteraction("did_forget_membership", f)
+        count = await self.db_pool.runInteraction("did_forget_membership", f)
         return count == 0
 
     @cached()
-    def get_forgotten_rooms_for_user(self, user_id):
+    def get_forgotten_rooms_for_user(self, user_id: str):
         """Gets all rooms the user has forgotten.
 
         Args:
-            user_id (str)
+            user_id
 
         Returns:
             Deferred[set[str]]
@@ -819,18 +753,17 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
         )
 
-    @defer.inlineCallbacks
-    def get_rooms_user_has_been_in(self, user_id):
+    async def get_rooms_user_has_been_in(self, user_id: str) -> Set[str]:
         """Get all rooms that the user has ever been in.
 
         Args:
-            user_id (str)
+            user_id: The user ID to get the rooms of.
 
         Returns:
-            Deferred[set[str]]: Set of room IDs.
+            Set of room IDs.
         """
 
-        room_ids = yield self.db_pool.simple_select_onecol(
+        room_ids = await self.db_pool.simple_select_onecol(
             table="room_memberships",
             keyvalues={"membership": Membership.JOIN, "user_id": user_id},
             retcol="room_id",
@@ -905,8 +838,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
             where_clause="forgotten = 1",
         )
 
-    @defer.inlineCallbacks
-    def _background_add_membership_profile(self, progress, batch_size):
+    async def _background_add_membership_profile(self, progress, batch_size):
         target_min_stream_id = progress.get(
             "target_min_stream_id_inclusive", self._min_stream_order_on_start
         )
@@ -971,19 +903,18 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
 
             return len(rows)
 
-        result = yield self.db_pool.runInteraction(
+        result = await self.db_pool.runInteraction(
             _MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn
         )
 
         if not result:
-            yield self.db_pool.updates._end_background_update(
+            await self.db_pool.updates._end_background_update(
                 _MEMBERSHIP_PROFILE_UPDATE_NAME
             )
 
         return result
 
-    @defer.inlineCallbacks
-    def _background_current_state_membership(self, progress, batch_size):
+    async def _background_current_state_membership(self, progress, batch_size):
         """Update the new membership column on current_state_events.
 
         This works by iterating over all rooms in alphebetical order.
@@ -1029,14 +960,14 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
         # string, which will compare before all room IDs correctly.
         last_processed_room = progress.get("last_processed_room", "")
 
-        row_count, finished = yield self.db_pool.runInteraction(
+        row_count, finished = await self.db_pool.runInteraction(
             "_background_current_state_membership_update",
             _background_current_state_membership_txn,
             last_processed_room,
         )
 
         if finished:
-            yield self.db_pool.updates._end_background_update(
+            await self.db_pool.updates._end_background_update(
                 _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME
             )
 
@@ -1047,7 +978,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super(RoomMemberStore, self).__init__(database, db_conn, hs)
 
-    def forget(self, user_id, room_id):
+    def forget(self, user_id: str, room_id: str):
         """Indicate that user_id wishes to discard history for room_id."""
 
         def f(txn):
@@ -1088,17 +1019,19 @@ class _JoinedHostsCache(object):
 
         self._len = 0
 
-    @defer.inlineCallbacks
-    def get_destinations(self, state_entry):
+    async def get_destinations(self, state_entry: "_StateCacheEntry") -> Set[str]:
         """Get set of destinations for a state entry
 
         Args:
-            state_entry(synapse.state._StateCacheEntry)
+            state_entry
+
+        Returns:
+            The destinations as a set.
         """
         if state_entry.state_group == self.state_group:
             return frozenset(self.hosts_to_joined_users)
 
-        with (yield self.linearizer.queue(())):
+        with (await self.linearizer.queue(())):
             if state_entry.state_group == self.state_group:
                 pass
             elif state_entry.prev_group == self.state_group:
@@ -1110,7 +1043,7 @@ class _JoinedHostsCache(object):
                     user_id = state_key
                     known_joins = self.hosts_to_joined_users.setdefault(host, set())
 
-                    event = yield self.store.get_event(event_id)
+                    event = await self.store.get_event(event_id)
                     if event.membership == Membership.JOIN:
                         known_joins.add(user_id)
                     else:
@@ -1119,7 +1052,7 @@ class _JoinedHostsCache(object):
                         if not known_joins:
                             self.hosts_to_joined_users.pop(host, None)
             else:
-                joined_users = yield self.store.get_joined_users_from_state(
+                joined_users = await self.store.get_joined_users_from_state(
                     self.room_id, state_entry
                 )