summary refs log tree commit diff
path: root/synapse/storage/databases/main/roommember.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/roommember.py')
-rw-r--r--synapse/storage/databases/main/roommember.py126
1 files changed, 83 insertions, 43 deletions
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 48e83592e7..608d40dfa1 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -37,7 +37,12 @@ from synapse.metrics.background_process_metrics import (
     wrap_as_background_process,
 )
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.engines import Sqlite3Engine
 from synapse.storage.roommember import (
@@ -46,7 +51,7 @@ from synapse.storage.roommember import (
     ProfileInfo,
     RoomsForUser,
 )
-from synapse.types import PersistedEventPosition, get_domain_from_id
+from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain_from_id
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
@@ -115,7 +120,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             )
 
     @wrap_as_background_process("_count_known_servers")
-    async def _count_known_servers(self):
+    async def _count_known_servers(self) -> int:
         """
         Count the servers that this server knows about.
 
@@ -123,7 +128,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         `synapse_federation_known_servers` LaterGauge to collect.
         """
 
-        def _transact(txn):
+        def _transact(txn: LoggingTransaction) -> int:
             if isinstance(self.database_engine, Sqlite3Engine):
                 query = """
                     SELECT COUNT(DISTINCT substr(out.user_id, pos+1))
@@ -150,7 +155,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         self._known_servers_count = max([count, 1])
         return self._known_servers_count
 
-    def _check_safe_current_state_events_membership_updated_txn(self, txn):
+    def _check_safe_current_state_events_membership_updated_txn(
+        self, txn: LoggingTransaction
+    ) -> None:
         """Checks if it is safe to assume the new current_state_events
         membership column is up to date
         """
@@ -182,7 +189,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             "get_users_in_room", self.get_users_in_room_txn, room_id
         )
 
-    def get_users_in_room_txn(self, txn, room_id: str) -> List[str]:
+    def get_users_in_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[str]:
         # If we can assume current_state_events.membership is up to date
         # then we can avoid a join, which is a Very Good Thing given how
         # frequently this function gets called.
@@ -222,7 +229,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             A mapping from user ID to ProfileInfo.
         """
 
-        def _get_users_in_room_with_profiles(txn) -> Dict[str, ProfileInfo]:
+        def _get_users_in_room_with_profiles(
+            txn: LoggingTransaction,
+        ) -> Dict[str, ProfileInfo]:
             sql = """
                 SELECT state_key, display_name, avatar_url FROM room_memberships as m
                 INNER JOIN current_state_events as c
@@ -250,7 +259,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             dict of membership states, pointing to a MemberSummary named tuple.
         """
 
-        def _get_room_summary_txn(txn):
+        def _get_room_summary_txn(
+            txn: LoggingTransaction,
+        ) -> Dict[str, MemberSummary]:
             # first get counts.
             # We do this all in one transaction to keep the cache small.
             # FIXME: get rid of this when we have room_stats
@@ -279,7 +290,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 """
 
             txn.execute(sql, (room_id,))
-            res = {}
+            res: Dict[str, MemberSummary] = {}
             for count, membership in txn:
                 res.setdefault(membership, MemberSummary([], count))
 
@@ -400,7 +411,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
     def _get_rooms_for_local_user_where_membership_is_txn(
         self,
-        txn,
+        txn: LoggingTransaction,
         user_id: str,
         membership_list: List[str],
     ) -> List[RoomsForUser]:
@@ -488,7 +499,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         )
 
     def _get_rooms_for_user_with_stream_ordering_txn(
-        self, txn, user_id: str
+        self, txn: LoggingTransaction, user_id: str
     ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
         # We use `current_state_events` here and not `local_current_membership`
         # as a) this gets called with remote users and b) this only gets called
@@ -542,7 +553,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         )
 
     def _get_rooms_for_users_with_stream_ordering_txn(
-        self, txn, user_ids: Collection[str]
+        self, txn: LoggingTransaction, user_ids: Collection[str]
     ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]:
 
         clause, args = make_in_list_sql_clause(
@@ -575,7 +586,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         txn.execute(sql, [Membership.JOIN] + args)
 
-        result = {user_id: set() for user_id in user_ids}
+        result: Dict[str, Set[GetRoomsForUserWithStreamOrdering]] = {
+            user_id: set() for user_id in user_ids
+        }
         for user_id, room_id, instance, stream_id in txn:
             result[user_id].add(
                 GetRoomsForUserWithStreamOrdering(
@@ -595,7 +608,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         if not user_ids:
             return set()
 
-        def _get_users_server_still_shares_room_with_txn(txn):
+        def _get_users_server_still_shares_room_with_txn(
+            txn: LoggingTransaction,
+        ) -> Set[str]:
             sql = """
                 SELECT state_key FROM current_state_events
                 WHERE
@@ -657,7 +672,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
     async def get_joined_users_from_context(
         self, event: EventBase, context: EventContext
     ) -> Dict[str, ProfileInfo]:
-        state_group = context.state_group
+        state_group: Union[object, int] = context.state_group
         if not state_group:
             # If state_group is None it means it has yet to be assigned a
             # state group, i.e. we need to make sure that calls with a state_group
@@ -666,14 +681,16 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             state_group = object()
 
         current_state_ids = await context.get_current_state_ids()
+        assert current_state_ids is not None
+        assert state_group is not None
         return await self._get_joined_users_from_context(
             event.room_id, state_group, current_state_ids, event=event, context=context
         )
 
     async def get_joined_users_from_state(
-        self, room_id, state_entry
+        self, room_id: str, state_entry: "_StateCacheEntry"
     ) -> Dict[str, ProfileInfo]:
-        state_group = state_entry.state_group
+        state_group: Union[object, int] = state_entry.state_group
         if not state_group:
             # If state_group is None it means it has yet to be assigned a
             # state group, i.e. we need to make sure that calls with a state_group
@@ -681,6 +698,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             # To do this we set the state_group to a new object as object() != object()
             state_group = object()
 
+        assert state_group is not None
         with Measure(self._clock, "get_joined_users_from_state"):
             return await self._get_joined_users_from_context(
                 room_id, state_group, state_entry.state, context=state_entry
@@ -689,12 +707,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
     @cached(num_args=2, cache_context=True, iterable=True, max_entries=100000)
     async def _get_joined_users_from_context(
         self,
-        room_id,
-        state_group,
-        current_state_ids,
-        cache_context,
-        event=None,
-        context=None,
+        room_id: str,
+        state_group: Union[object, int],
+        current_state_ids: StateMap[str],
+        cache_context: _CacheContext,
+        event: Optional[EventBase] = None,
+        context: Optional[Union[EventContext, "_StateCacheEntry"]] = None,
     ) -> Dict[str, ProfileInfo]:
         # We don't use `state_group`, it's there so that we can cache based
         # on it. However, it's important that it's never None, since two current_states
@@ -765,14 +783,18 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return users_in_room
 
     @cached(max_entries=10000)
-    def _get_joined_profile_from_event_id(self, event_id):
+    def _get_joined_profile_from_event_id(
+        self, event_id: str
+    ) -> Optional[Tuple[str, ProfileInfo]]:
         raise NotImplementedError()
 
     @cachedList(
         cached_method_name="_get_joined_profile_from_event_id",
         list_name="event_ids",
     )
-    async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
+    async def _get_joined_profiles_from_event_ids(
+        self, event_ids: Iterable[str]
+    ) -> Dict[str, Optional[Tuple[str, ProfileInfo]]]:
         """For given set of member event_ids check if they point to a join
         event and if so return the associated user and profile info.
 
@@ -780,8 +802,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             event_ids: The member event IDs to lookup
 
         Returns:
-            dict[str, Tuple[str, ProfileInfo]|None]: Map from event ID
-            to `user_id` and ProfileInfo (or None if not join event).
+            Map from event ID to `user_id` and ProfileInfo (or None if not join event).
         """
 
         rows = await self.db_pool.simple_select_many_batch(
@@ -847,8 +868,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return True
 
-    async def get_joined_hosts(self, room_id: str, state_entry):
-        state_group = state_entry.state_group
+    async def get_joined_hosts(
+        self, room_id: str, state_entry: "_StateCacheEntry"
+    ) -> FrozenSet[str]:
+        state_group: Union[object, int] = state_entry.state_group
         if not state_group:
             # If state_group is None it means it has yet to be assigned a
             # state group, i.e. we need to make sure that calls with a state_group
@@ -856,6 +879,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             # To do this we set the state_group to a new object as object() != object()
             state_group = object()
 
+        assert state_group is not None
         with Measure(self._clock, "get_joined_hosts"):
             return await self._get_joined_hosts(
                 room_id, state_group, state_entry=state_entry
@@ -863,7 +887,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
     @cached(num_args=2, max_entries=10000, iterable=True)
     async def _get_joined_hosts(
-        self, room_id: str, state_group: int, state_entry: "_StateCacheEntry"
+        self,
+        room_id: str,
+        state_group: Union[object, int],
+        state_entry: "_StateCacheEntry",
     ) -> FrozenSet[str]:
         # We don't use `state_group`, it's there so that we can cache based on
         # it. However, its important that its never None, since two
@@ -881,7 +908,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         # `get_joined_hosts` is called with the "current" state group for the
         # room, and so consecutive calls will be for consecutive state groups
         # which point to the previous state group.
-        cache = await self._get_joined_hosts_cache(room_id)
+        cache = await self._get_joined_hosts_cache(room_id)  # type: ignore[misc]
 
         # If the state group in the cache matches, we already have the data we need.
         if state_entry.state_group == cache.state_group:
@@ -897,6 +924,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             elif state_entry.prev_group == cache.state_group:
                 # The cached work is for the previous state group, so we work out
                 # the delta.
+                assert state_entry.delta_ids is not None
                 for (typ, state_key), event_id in state_entry.delta_ids.items():
                     if typ != EventTypes.Member:
                         continue
@@ -942,7 +970,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         Returns False if they have since re-joined."""
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> int:
             sql = (
                 "SELECT"
                 "  COUNT(*)"
@@ -973,7 +1001,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             The forgotten rooms.
         """
 
-        def _get_forgotten_rooms_for_user_txn(txn):
+        def _get_forgotten_rooms_for_user_txn(txn: LoggingTransaction) -> Set[str]:
             # This is a slightly convoluted query that first looks up all rooms
             # that the user has forgotten in the past, then rechecks that list
             # to see if any have subsequently been updated. This is done so that
@@ -1076,7 +1104,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             clause,
         )
 
-        def _is_local_host_in_room_ignoring_users_txn(txn):
+        def _is_local_host_in_room_ignoring_users_txn(
+            txn: LoggingTransaction,
+        ) -> bool:
             txn.execute(sql, (room_id, Membership.JOIN, *args))
 
             return bool(txn.fetchone())
@@ -1110,15 +1140,17 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
             where_clause="forgotten = 1",
         )
 
-    async def _background_add_membership_profile(self, progress, batch_size):
+    async def _background_add_membership_profile(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         target_min_stream_id = progress.get(
-            "target_min_stream_id_inclusive", self._min_stream_order_on_start
+            "target_min_stream_id_inclusive", self._min_stream_order_on_start  # type: ignore[attr-defined]
         )
         max_stream_id = progress.get(
-            "max_stream_id_exclusive", self._stream_order_on_start + 1
+            "max_stream_id_exclusive", self._stream_order_on_start + 1  # type: ignore[attr-defined]
         )
 
-        def add_membership_profile_txn(txn):
+        def add_membership_profile_txn(txn: LoggingTransaction) -> int:
             sql = """
                 SELECT stream_ordering, event_id, events.room_id, event_json.json
                 FROM events
@@ -1182,13 +1214,17 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
 
         return result
 
-    async def _background_current_state_membership(self, progress, batch_size):
+    async def _background_current_state_membership(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """Update the new membership column on current_state_events.
 
         This works by iterating over all rooms in alphebetical order.
         """
 
-        def _background_current_state_membership_txn(txn, last_processed_room):
+        def _background_current_state_membership_txn(
+            txn: LoggingTransaction, last_processed_room: str
+        ) -> Tuple[int, bool]:
             processed = 0
             while processed < batch_size:
                 txn.execute(
@@ -1242,7 +1278,11 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
         return row_count
 
 
-class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
+class RoomMemberStore(
+    RoomMemberWorkerStore,
+    RoomMemberBackgroundUpdateStore,
+    CacheInvalidationWorkerStore,
+):
     def __init__(
         self,
         database: DatabasePool,
@@ -1254,7 +1294,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
     async def forget(self, user_id: str, room_id: str) -> None:
         """Indicate that user_id wishes to discard history for room_id."""
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> None:
             sql = (
                 "UPDATE"
                 "  room_memberships"
@@ -1288,5 +1328,5 @@ class _JoinedHostsCache:
     # equal to anything else).
     state_group: Union[object, int] = attr.Factory(object)
 
-    def __len__(self):
+    def __len__(self) -> int:
         return sum(len(v) for v in self.hosts_to_joined_users.values())