summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
authorDirk Klimpel <5740567+dklimpel@users.noreply.github.com>2022-05-17 16:29:06 +0200
committerGitHub <noreply@github.com>2022-05-17 15:29:06 +0100
commit6edefef60289cc54e17fd6af838eb66c4973f5f5 (patch)
treefd254e6d378e0f877c39826b89e8db47785abcdb /synapse/storage
parentAdd a new room version for MSC3787's knock+restricted join rule (#12623) (diff)
downloadsynapse-6edefef60289cc54e17fd6af838eb66c4973f5f5.tar.xz
Add some type hints to datastore (#12717)
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/databases/main/__init__.py8
-rw-r--r--synapse/storage/databases/main/metrics.py56
-rw-r--r--synapse/storage/databases/main/push_rule.py184
-rw-r--r--synapse/storage/databases/main/roommember.py126
4 files changed, 229 insertions, 145 deletions
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 5895b89202..d545a1c002 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -26,11 +26,7 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.stats import UserSortOrder
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
 from synapse.storage.types import Cursor
-from synapse.storage.util.id_generators import (
-    IdGenerator,
-    MultiWriterIdGenerator,
-    StreamIdGenerator,
-)
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
 from synapse.types import JsonDict, get_domain_from_id
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
@@ -155,8 +151,6 @@ class DataStore(
             ],
         )
 
-        self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
-        self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
         self._group_updates_id_gen = StreamIdGenerator(
             db_conn, "local_group_updates", "stream_id"
         )
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index d03555a585..14294a0bb8 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -14,16 +14,19 @@
 import calendar
 import logging
 import time
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING, Dict, List, Tuple, cast
 
 from synapse.metrics import GaugeBucketCollector
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.event_push_actions import (
     EventPushActionsWorkerStore,
 )
-from synapse.storage.types import Cursor
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -73,7 +76,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
 
     @wrap_as_background_process("read_forward_extremities")
     async def _read_forward_extremities(self) -> None:
-        def fetch(txn):
+        def fetch(txn: LoggingTransaction) -> List[Tuple[int, int]]:
             txn.execute(
                 """
                 SELECT t1.c, t2.c
@@ -86,7 +89,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
                 ) t2 ON t1.room_id = t2.room_id
                 """
             )
-            return txn.fetchall()
+            return cast(List[Tuple[int, int]], txn.fetchall())
 
         res = await self.db_pool.runInteraction("read_forward_extremities", fetch)
 
@@ -104,20 +107,20 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
         call to this function, it will return None.
         """
 
-        def _count_messages(txn):
+        def _count_messages(txn: LoggingTransaction) -> int:
             sql = """
                 SELECT COUNT(*) FROM events
                 WHERE type = 'm.room.encrypted'
                 AND stream_ordering > ?
             """
             txn.execute(sql, (self.stream_ordering_day_ago,))
-            (count,) = txn.fetchone()
+            (count,) = cast(Tuple[int], txn.fetchone())
             return count
 
         return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)
 
     async def count_daily_sent_e2ee_messages(self) -> int:
-        def _count_messages(txn):
+        def _count_messages(txn: LoggingTransaction) -> int:
             # This is good enough as if you have silly characters in your own
             # hostname then that's your own fault.
             like_clause = "%:" + self.hs.hostname
@@ -130,7 +133,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             """
 
             txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
-            (count,) = txn.fetchone()
+            (count,) = cast(Tuple[int], txn.fetchone())
             return count
 
         return await self.db_pool.runInteraction(
@@ -138,14 +141,14 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
         )
 
     async def count_daily_active_e2ee_rooms(self) -> int:
-        def _count(txn):
+        def _count(txn: LoggingTransaction) -> int:
             sql = """
                 SELECT COUNT(DISTINCT room_id) FROM events
                 WHERE type = 'm.room.encrypted'
                 AND stream_ordering > ?
             """
             txn.execute(sql, (self.stream_ordering_day_ago,))
-            (count,) = txn.fetchone()
+            (count,) = cast(Tuple[int], txn.fetchone())
             return count
 
         return await self.db_pool.runInteraction(
@@ -160,20 +163,20 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
         call to this function, it will return None.
         """
 
-        def _count_messages(txn):
+        def _count_messages(txn: LoggingTransaction) -> int:
             sql = """
                 SELECT COUNT(*) FROM events
                 WHERE type = 'm.room.message'
                 AND stream_ordering > ?
             """
             txn.execute(sql, (self.stream_ordering_day_ago,))
-            (count,) = txn.fetchone()
+            (count,) = cast(Tuple[int], txn.fetchone())
             return count
 
         return await self.db_pool.runInteraction("count_messages", _count_messages)
 
     async def count_daily_sent_messages(self) -> int:
-        def _count_messages(txn):
+        def _count_messages(txn: LoggingTransaction) -> int:
             # This is good enough as if you have silly characters in your own
             # hostname then that's your own fault.
             like_clause = "%:" + self.hs.hostname
@@ -186,7 +189,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             """
 
             txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
-            (count,) = txn.fetchone()
+            (count,) = cast(Tuple[int], txn.fetchone())
             return count
 
         return await self.db_pool.runInteraction(
@@ -194,14 +197,14 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
         )
 
     async def count_daily_active_rooms(self) -> int:
-        def _count(txn):
+        def _count(txn: LoggingTransaction) -> int:
             sql = """
                 SELECT COUNT(DISTINCT room_id) FROM events
                 WHERE type = 'm.room.message'
                 AND stream_ordering > ?
             """
             txn.execute(sql, (self.stream_ordering_day_ago,))
-            (count,) = txn.fetchone()
+            (count,) = cast(Tuple[int], txn.fetchone())
             return count
 
         return await self.db_pool.runInteraction("count_daily_active_rooms", _count)
@@ -227,7 +230,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
             "count_monthly_users", self._count_users, thirty_days_ago
         )
 
-    def _count_users(self, txn: Cursor, time_from: int) -> int:
+    def _count_users(self, txn: LoggingTransaction, time_from: int) -> int:
         """
         Returns number of users seen in the past time_from period
         """
@@ -242,7 +245,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
         # Mypy knows that fetchone() might return None if there are no rows.
         # We know better: "SELECT COUNT(...) FROM ..." without any GROUP BY always
         # returns exactly one row.
-        (count,) = txn.fetchone()  # type: ignore[misc]
+        (count,) = cast(Tuple[int], txn.fetchone())
         return count
 
     async def count_r30_users(self) -> Dict[str, int]:
@@ -256,7 +259,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
              A mapping of counts globally as well as broken out by platform.
         """
 
-        def _count_r30_users(txn):
+        def _count_r30_users(txn: LoggingTransaction) -> Dict[str, int]:
             thirty_days_in_secs = 86400 * 30
             now = int(self._clock.time())
             thirty_days_ago_in_secs = now - thirty_days_in_secs
@@ -321,7 +324,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
 
             txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
 
-            (count,) = txn.fetchone()
+            (count,) = cast(Tuple[int], txn.fetchone())
             results["all"] = count
 
             return results
@@ -348,7 +351,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
               - "web" (any web application -- it's not possible to distinguish Element Web here)
         """
 
-        def _count_r30v2_users(txn):
+        def _count_r30v2_users(txn: LoggingTransaction) -> Dict[str, int]:
             thirty_days_in_secs = 86400 * 30
             now = int(self._clock.time())
             sixty_days_ago_in_secs = now - 2 * thirty_days_in_secs
@@ -445,11 +448,8 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
                     thirty_days_in_secs * 1000,
                 ),
             )
-            row = txn.fetchone()
-            if row is None:
-                results["all"] = 0
-            else:
-                results["all"] = row[0]
+            (count,) = cast(Tuple[int], txn.fetchone())
+            results["all"] = count
 
             return results
 
@@ -471,7 +471,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
         Generates daily visit data for use in cohort/ retention analysis
         """
 
-        def _generate_user_daily_visits(txn):
+        def _generate_user_daily_visits(txn: LoggingTransaction) -> None:
             logger.info("Calling _generate_user_daily_visits")
             today_start = self._get_start_of_day()
             a_day_in_milliseconds = 24 * 60 * 60 * 1000
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 4ed913e248..0e2855fb44 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -14,14 +14,18 @@
 # limitations under the License.
 import abc
 import logging
-from typing import TYPE_CHECKING, Dict, List, Tuple, Union
+from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Tuple, Union, cast
 
 from synapse.api.errors import StoreError
 from synapse.config.homeserver import ExperimentalConfig
 from synapse.push.baserules import list_with_base_rules
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.pusher import PusherWorkerStore
@@ -30,9 +34,12 @@ from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
 from synapse.storage.util.id_generators import (
+    AbstractStreamIdGenerator,
     AbstractStreamIdTracker,
+    IdGenerator,
     StreamIdGenerator,
 )
+from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -57,7 +64,11 @@ def _is_experimental_rule_enabled(
     return True
 
 
-def _load_rules(rawrules, enabled_map, experimental_config: ExperimentalConfig):
+def _load_rules(
+    rawrules: List[JsonDict],
+    enabled_map: Dict[str, bool],
+    experimental_config: ExperimentalConfig,
+) -> List[JsonDict]:
     ruleslist = []
     for rawrule in rawrules:
         rule = dict(rawrule)
@@ -137,7 +148,7 @@ class PushRulesWorkerStore(
         )
 
     @abc.abstractmethod
-    def get_max_push_rules_stream_id(self):
+    def get_max_push_rules_stream_id(self) -> int:
         """Get the position of the push rules stream.
 
         Returns:
@@ -146,7 +157,7 @@ class PushRulesWorkerStore(
         raise NotImplementedError()
 
     @cached(max_entries=5000)
-    async def get_push_rules_for_user(self, user_id):
+    async def get_push_rules_for_user(self, user_id: str) -> List[JsonDict]:
         rows = await self.db_pool.simple_select_list(
             table="push_rules",
             keyvalues={"user_name": user_id},
@@ -168,7 +179,7 @@ class PushRulesWorkerStore(
         return _load_rules(rows, enabled_map, self.hs.config.experimental)
 
     @cached(max_entries=5000)
-    async def get_push_rules_enabled_for_user(self, user_id) -> Dict[str, bool]:
+    async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
         results = await self.db_pool.simple_select_list(
             table="push_rules_enable",
             keyvalues={"user_name": user_id},
@@ -184,13 +195,13 @@ class PushRulesWorkerStore(
             return False
         else:
 
-            def have_push_rules_changed_txn(txn):
+            def have_push_rules_changed_txn(txn: LoggingTransaction) -> bool:
                 sql = (
                     "SELECT COUNT(stream_id) FROM push_rules_stream"
                     " WHERE user_id = ? AND ? < stream_id"
                 )
                 txn.execute(sql, (user_id, last_id))
-                (count,) = txn.fetchone()
+                (count,) = cast(Tuple[int], txn.fetchone())
                 return bool(count)
 
             return await self.db_pool.runInteraction(
@@ -202,11 +213,13 @@ class PushRulesWorkerStore(
         list_name="user_ids",
         num_args=1,
     )
-    async def bulk_get_push_rules(self, user_ids):
+    async def bulk_get_push_rules(
+        self, user_ids: Collection[str]
+    ) -> Dict[str, List[JsonDict]]:
         if not user_ids:
             return {}
 
-        results = {user_id: [] for user_id in user_ids}
+        results: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids}
 
         rows = await self.db_pool.simple_select_many_batch(
             table="push_rules",
@@ -250,7 +263,7 @@ class PushRulesWorkerStore(
                 condition["pattern"] = new_room_id
 
         # Add the rule for the new room
-        await self.add_push_rule(
+        await self.add_push_rule(  # type: ignore[attr-defined]
             user_id=user_id,
             rule_id=new_rule_id,
             priority_class=rule["priority_class"],
@@ -286,11 +299,13 @@ class PushRulesWorkerStore(
         list_name="user_ids",
         num_args=1,
     )
-    async def bulk_get_push_rules_enabled(self, user_ids):
+    async def bulk_get_push_rules_enabled(
+        self, user_ids: Collection[str]
+    ) -> Dict[str, Dict[str, bool]]:
         if not user_ids:
             return {}
 
-        results = {user_id: {} for user_id in user_ids}
+        results: Dict[str, Dict[str, bool]] = {user_id: {} for user_id in user_ids}
 
         rows = await self.db_pool.simple_select_many_batch(
             table="push_rules_enable",
@@ -306,7 +321,7 @@ class PushRulesWorkerStore(
 
     async def get_all_push_rule_updates(
         self, instance_name: str, last_id: int, current_id: int, limit: int
-    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
+    ) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]:
         """Get updates for push_rules replication stream.
 
         Args:
@@ -331,7 +346,9 @@ class PushRulesWorkerStore(
         if last_id == current_id:
             return [], current_id, False
 
-        def get_all_push_rule_updates_txn(txn):
+        def get_all_push_rule_updates_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Tuple[int, Tuple[str]]], int, bool]:
             sql = """
                 SELECT stream_id, user_id
                 FROM push_rules_stream
@@ -340,7 +357,10 @@ class PushRulesWorkerStore(
                 LIMIT ?
             """
             txn.execute(sql, (last_id, current_id, limit))
-            updates = [(stream_id, (user_id,)) for stream_id, user_id in txn]
+            updates = cast(
+                List[Tuple[int, Tuple[str]]],
+                [(stream_id, (user_id,)) for stream_id, user_id in txn],
+            )
 
             limited = False
             upper_bound = current_id
@@ -356,15 +376,30 @@ class PushRulesWorkerStore(
 
 
 class PushRuleStore(PushRulesWorkerStore):
+    # Because we have write access, this will be a StreamIdGenerator
+    # (see PushRulesWorkerStore.__init__)
+    _push_rules_stream_id_gen: AbstractStreamIdGenerator
+
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
+        super().__init__(database, db_conn, hs)
+
+        self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
+        self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
+
     async def add_push_rule(
         self,
-        user_id,
-        rule_id,
-        priority_class,
-        conditions,
-        actions,
-        before=None,
-        after=None,
+        user_id: str,
+        rule_id: str,
+        priority_class: int,
+        conditions: List[Dict[str, str]],
+        actions: List[Union[JsonDict, str]],
+        before: Optional[str] = None,
+        after: Optional[str] = None,
     ) -> None:
         conditions_json = json_encoder.encode(conditions)
         actions_json = json_encoder.encode(actions)
@@ -400,17 +435,17 @@ class PushRuleStore(PushRulesWorkerStore):
 
     def _add_push_rule_relative_txn(
         self,
-        txn,
-        stream_id,
-        event_stream_ordering,
-        user_id,
-        rule_id,
-        priority_class,
-        conditions_json,
-        actions_json,
-        before,
-        after,
-    ):
+        txn: LoggingTransaction,
+        stream_id: int,
+        event_stream_ordering: int,
+        user_id: str,
+        rule_id: str,
+        priority_class: int,
+        conditions_json: str,
+        actions_json: str,
+        before: str,
+        after: str,
+    ) -> None:
         # Lock the table since otherwise we'll have annoying races between the
         # SELECT here and the UPSERT below.
         self.database_engine.lock_table(txn, "push_rules")
@@ -470,15 +505,15 @@ class PushRuleStore(PushRulesWorkerStore):
 
     def _add_push_rule_highest_priority_txn(
         self,
-        txn,
-        stream_id,
-        event_stream_ordering,
-        user_id,
-        rule_id,
-        priority_class,
-        conditions_json,
-        actions_json,
-    ):
+        txn: LoggingTransaction,
+        stream_id: int,
+        event_stream_ordering: int,
+        user_id: str,
+        rule_id: str,
+        priority_class: int,
+        conditions_json: str,
+        actions_json: str,
+    ) -> None:
         # Lock the table since otherwise we'll have annoying races between the
         # SELECT here and the UPSERT below.
         self.database_engine.lock_table(txn, "push_rules")
@@ -510,17 +545,17 @@ class PushRuleStore(PushRulesWorkerStore):
 
     def _upsert_push_rule_txn(
         self,
-        txn,
-        stream_id,
-        event_stream_ordering,
-        user_id,
-        rule_id,
-        priority_class,
-        priority,
-        conditions_json,
-        actions_json,
-        update_stream=True,
-    ):
+        txn: LoggingTransaction,
+        stream_id: int,
+        event_stream_ordering: int,
+        user_id: str,
+        rule_id: str,
+        priority_class: int,
+        priority: int,
+        conditions_json: str,
+        actions_json: str,
+        update_stream: bool = True,
+    ) -> None:
         """Specialised version of simple_upsert_txn that picks a push_rule_id
         using the _push_rule_id_gen if it needs to insert the rule. It assumes
         that the "push_rules" table is locked"""
@@ -600,7 +635,11 @@ class PushRuleStore(PushRulesWorkerStore):
             rule_id: The rule_id of the rule to be deleted
         """
 
-        def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
+        def delete_push_rule_txn(
+            txn: LoggingTransaction,
+            stream_id: int,
+            event_stream_ordering: int,
+        ) -> None:
             # we don't use simple_delete_one_txn because that would fail if the
             # user did not have a push_rule_enable row.
             self.db_pool.simple_delete_txn(
@@ -661,14 +700,14 @@ class PushRuleStore(PushRulesWorkerStore):
 
     def _set_push_rule_enabled_txn(
         self,
-        txn,
-        stream_id,
-        event_stream_ordering,
-        user_id,
-        rule_id,
-        enabled,
-        is_default_rule,
-    ):
+        txn: LoggingTransaction,
+        stream_id: int,
+        event_stream_ordering: int,
+        user_id: str,
+        rule_id: str,
+        enabled: bool,
+        is_default_rule: bool,
+    ) -> None:
         new_id = self._push_rules_enable_id_gen.get_next()
 
         if not is_default_rule:
@@ -740,7 +779,11 @@ class PushRuleStore(PushRulesWorkerStore):
         """
         actions_json = json_encoder.encode(actions)
 
-        def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
+        def set_push_rule_actions_txn(
+            txn: LoggingTransaction,
+            stream_id: int,
+            event_stream_ordering: int,
+        ) -> None:
             if is_default_rule:
                 # Add a dummy rule to the rules table with the user specified
                 # actions.
@@ -794,8 +837,15 @@ class PushRuleStore(PushRulesWorkerStore):
             )
 
     def _insert_push_rules_update_txn(
-        self, txn, stream_id, event_stream_ordering, user_id, rule_id, op, data=None
-    ):
+        self,
+        txn: LoggingTransaction,
+        stream_id: int,
+        event_stream_ordering: int,
+        user_id: str,
+        rule_id: str,
+        op: str,
+        data: Optional[JsonDict] = None,
+    ) -> None:
         values = {
             "stream_id": stream_id,
             "event_stream_ordering": event_stream_ordering,
@@ -814,5 +864,5 @@ class PushRuleStore(PushRulesWorkerStore):
             self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
         )
 
-    def get_max_push_rules_stream_id(self):
+    def get_max_push_rules_stream_id(self) -> int:
         return self._push_rules_stream_id_gen.get_current_token()
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 48e83592e7..608d40dfa1 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -37,7 +37,12 @@ from synapse.metrics.background_process_metrics import (
     wrap_as_background_process,
 )
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.engines import Sqlite3Engine
 from synapse.storage.roommember import (
@@ -46,7 +51,7 @@ from synapse.storage.roommember import (
     ProfileInfo,
     RoomsForUser,
 )
-from synapse.types import PersistedEventPosition, get_domain_from_id
+from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain_from_id
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
@@ -115,7 +120,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             )
 
     @wrap_as_background_process("_count_known_servers")
-    async def _count_known_servers(self):
+    async def _count_known_servers(self) -> int:
         """
         Count the servers that this server knows about.
 
@@ -123,7 +128,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         `synapse_federation_known_servers` LaterGauge to collect.
         """
 
-        def _transact(txn):
+        def _transact(txn: LoggingTransaction) -> int:
             if isinstance(self.database_engine, Sqlite3Engine):
                 query = """
                     SELECT COUNT(DISTINCT substr(out.user_id, pos+1))
@@ -150,7 +155,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         self._known_servers_count = max([count, 1])
         return self._known_servers_count
 
-    def _check_safe_current_state_events_membership_updated_txn(self, txn):
+    def _check_safe_current_state_events_membership_updated_txn(
+        self, txn: LoggingTransaction
+    ) -> None:
         """Checks if it is safe to assume the new current_state_events
         membership column is up to date
         """
@@ -182,7 +189,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             "get_users_in_room", self.get_users_in_room_txn, room_id
         )
 
-    def get_users_in_room_txn(self, txn, room_id: str) -> List[str]:
+    def get_users_in_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[str]:
         # If we can assume current_state_events.membership is up to date
         # then we can avoid a join, which is a Very Good Thing given how
         # frequently this function gets called.
@@ -222,7 +229,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             A mapping from user ID to ProfileInfo.
         """
 
-        def _get_users_in_room_with_profiles(txn) -> Dict[str, ProfileInfo]:
+        def _get_users_in_room_with_profiles(
+            txn: LoggingTransaction,
+        ) -> Dict[str, ProfileInfo]:
             sql = """
                 SELECT state_key, display_name, avatar_url FROM room_memberships as m
                 INNER JOIN current_state_events as c
@@ -250,7 +259,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             dict of membership states, pointing to a MemberSummary named tuple.
         """
 
-        def _get_room_summary_txn(txn):
+        def _get_room_summary_txn(
+            txn: LoggingTransaction,
+        ) -> Dict[str, MemberSummary]:
             # first get counts.
             # We do this all in one transaction to keep the cache small.
             # FIXME: get rid of this when we have room_stats
@@ -279,7 +290,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 """
 
             txn.execute(sql, (room_id,))
-            res = {}
+            res: Dict[str, MemberSummary] = {}
             for count, membership in txn:
                 res.setdefault(membership, MemberSummary([], count))
 
@@ -400,7 +411,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
     def _get_rooms_for_local_user_where_membership_is_txn(
         self,
-        txn,
+        txn: LoggingTransaction,
         user_id: str,
         membership_list: List[str],
     ) -> List[RoomsForUser]:
@@ -488,7 +499,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         )
 
     def _get_rooms_for_user_with_stream_ordering_txn(
-        self, txn, user_id: str
+        self, txn: LoggingTransaction, user_id: str
     ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
         # We use `current_state_events` here and not `local_current_membership`
         # as a) this gets called with remote users and b) this only gets called
@@ -542,7 +553,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         )
 
     def _get_rooms_for_users_with_stream_ordering_txn(
-        self, txn, user_ids: Collection[str]
+        self, txn: LoggingTransaction, user_ids: Collection[str]
     ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]:
 
         clause, args = make_in_list_sql_clause(
@@ -575,7 +586,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         txn.execute(sql, [Membership.JOIN] + args)
 
-        result = {user_id: set() for user_id in user_ids}
+        result: Dict[str, Set[GetRoomsForUserWithStreamOrdering]] = {
+            user_id: set() for user_id in user_ids
+        }
         for user_id, room_id, instance, stream_id in txn:
             result[user_id].add(
                 GetRoomsForUserWithStreamOrdering(
@@ -595,7 +608,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         if not user_ids:
             return set()
 
-        def _get_users_server_still_shares_room_with_txn(txn):
+        def _get_users_server_still_shares_room_with_txn(
+            txn: LoggingTransaction,
+        ) -> Set[str]:
             sql = """
                 SELECT state_key FROM current_state_events
                 WHERE
@@ -657,7 +672,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
     async def get_joined_users_from_context(
         self, event: EventBase, context: EventContext
     ) -> Dict[str, ProfileInfo]:
-        state_group = context.state_group
+        state_group: Union[object, int] = context.state_group
         if not state_group:
             # If state_group is None it means it has yet to be assigned a
             # state group, i.e. we need to make sure that calls with a state_group
@@ -666,14 +681,16 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             state_group = object()
 
         current_state_ids = await context.get_current_state_ids()
+        assert current_state_ids is not None
+        assert state_group is not None
         return await self._get_joined_users_from_context(
             event.room_id, state_group, current_state_ids, event=event, context=context
         )
 
     async def get_joined_users_from_state(
-        self, room_id, state_entry
+        self, room_id: str, state_entry: "_StateCacheEntry"
     ) -> Dict[str, ProfileInfo]:
-        state_group = state_entry.state_group
+        state_group: Union[object, int] = state_entry.state_group
         if not state_group:
             # If state_group is None it means it has yet to be assigned a
             # state group, i.e. we need to make sure that calls with a state_group
@@ -681,6 +698,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             # To do this we set the state_group to a new object as object() != object()
             state_group = object()
 
+        assert state_group is not None
         with Measure(self._clock, "get_joined_users_from_state"):
             return await self._get_joined_users_from_context(
                 room_id, state_group, state_entry.state, context=state_entry
@@ -689,12 +707,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
     @cached(num_args=2, cache_context=True, iterable=True, max_entries=100000)
     async def _get_joined_users_from_context(
         self,
-        room_id,
-        state_group,
-        current_state_ids,
-        cache_context,
-        event=None,
-        context=None,
+        room_id: str,
+        state_group: Union[object, int],
+        current_state_ids: StateMap[str],
+        cache_context: _CacheContext,
+        event: Optional[EventBase] = None,
+        context: Optional[Union[EventContext, "_StateCacheEntry"]] = None,
     ) -> Dict[str, ProfileInfo]:
         # We don't use `state_group`, it's there so that we can cache based
         # on it. However, it's important that it's never None, since two current_states
@@ -765,14 +783,18 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return users_in_room
 
     @cached(max_entries=10000)
-    def _get_joined_profile_from_event_id(self, event_id):
+    def _get_joined_profile_from_event_id(
+        self, event_id: str
+    ) -> Optional[Tuple[str, ProfileInfo]]:
         raise NotImplementedError()
 
     @cachedList(
         cached_method_name="_get_joined_profile_from_event_id",
         list_name="event_ids",
     )
-    async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
+    async def _get_joined_profiles_from_event_ids(
+        self, event_ids: Iterable[str]
+    ) -> Dict[str, Optional[Tuple[str, ProfileInfo]]]:
         """For given set of member event_ids check if they point to a join
         event and if so return the associated user and profile info.
 
@@ -780,8 +802,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             event_ids: The member event IDs to lookup
 
         Returns:
-            dict[str, Tuple[str, ProfileInfo]|None]: Map from event ID
-            to `user_id` and ProfileInfo (or None if not join event).
+            Map from event ID to `user_id` and ProfileInfo (or None if not join event).
         """
 
         rows = await self.db_pool.simple_select_many_batch(
@@ -847,8 +868,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return True
 
-    async def get_joined_hosts(self, room_id: str, state_entry):
-        state_group = state_entry.state_group
+    async def get_joined_hosts(
+        self, room_id: str, state_entry: "_StateCacheEntry"
+    ) -> FrozenSet[str]:
+        state_group: Union[object, int] = state_entry.state_group
         if not state_group:
             # If state_group is None it means it has yet to be assigned a
             # state group, i.e. we need to make sure that calls with a state_group
@@ -856,6 +879,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             # To do this we set the state_group to a new object as object() != object()
             state_group = object()
 
+        assert state_group is not None
         with Measure(self._clock, "get_joined_hosts"):
             return await self._get_joined_hosts(
                 room_id, state_group, state_entry=state_entry
@@ -863,7 +887,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
     @cached(num_args=2, max_entries=10000, iterable=True)
     async def _get_joined_hosts(
-        self, room_id: str, state_group: int, state_entry: "_StateCacheEntry"
+        self,
+        room_id: str,
+        state_group: Union[object, int],
+        state_entry: "_StateCacheEntry",
     ) -> FrozenSet[str]:
         # We don't use `state_group`, it's there so that we can cache based on
         # it. However, its important that its never None, since two
@@ -881,7 +908,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         # `get_joined_hosts` is called with the "current" state group for the
         # room, and so consecutive calls will be for consecutive state groups
         # which point to the previous state group.
-        cache = await self._get_joined_hosts_cache(room_id)
+        cache = await self._get_joined_hosts_cache(room_id)  # type: ignore[misc]
 
         # If the state group in the cache matches, we already have the data we need.
         if state_entry.state_group == cache.state_group:
@@ -897,6 +924,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             elif state_entry.prev_group == cache.state_group:
                 # The cached work is for the previous state group, so we work out
                 # the delta.
+                assert state_entry.delta_ids is not None
                 for (typ, state_key), event_id in state_entry.delta_ids.items():
                     if typ != EventTypes.Member:
                         continue
@@ -942,7 +970,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         Returns False if they have since re-joined."""
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> int:
             sql = (
                 "SELECT"
                 "  COUNT(*)"
@@ -973,7 +1001,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             The forgotten rooms.
         """
 
-        def _get_forgotten_rooms_for_user_txn(txn):
+        def _get_forgotten_rooms_for_user_txn(txn: LoggingTransaction) -> Set[str]:
             # This is a slightly convoluted query that first looks up all rooms
             # that the user has forgotten in the past, then rechecks that list
             # to see if any have subsequently been updated. This is done so that
@@ -1076,7 +1104,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             clause,
         )
 
-        def _is_local_host_in_room_ignoring_users_txn(txn):
+        def _is_local_host_in_room_ignoring_users_txn(
+            txn: LoggingTransaction,
+        ) -> bool:
             txn.execute(sql, (room_id, Membership.JOIN, *args))
 
             return bool(txn.fetchone())
@@ -1110,15 +1140,17 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
             where_clause="forgotten = 1",
         )
 
-    async def _background_add_membership_profile(self, progress, batch_size):
+    async def _background_add_membership_profile(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         target_min_stream_id = progress.get(
-            "target_min_stream_id_inclusive", self._min_stream_order_on_start
+            "target_min_stream_id_inclusive", self._min_stream_order_on_start  # type: ignore[attr-defined]
         )
         max_stream_id = progress.get(
-            "max_stream_id_exclusive", self._stream_order_on_start + 1
+            "max_stream_id_exclusive", self._stream_order_on_start + 1  # type: ignore[attr-defined]
         )
 
-        def add_membership_profile_txn(txn):
+        def add_membership_profile_txn(txn: LoggingTransaction) -> int:
             sql = """
                 SELECT stream_ordering, event_id, events.room_id, event_json.json
                 FROM events
@@ -1182,13 +1214,17 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
 
         return result
 
-    async def _background_current_state_membership(self, progress, batch_size):
+    async def _background_current_state_membership(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """Update the new membership column on current_state_events.
 
         This works by iterating over all rooms in alphebetical order.
         """
 
-        def _background_current_state_membership_txn(txn, last_processed_room):
+        def _background_current_state_membership_txn(
+            txn: LoggingTransaction, last_processed_room: str
+        ) -> Tuple[int, bool]:
             processed = 0
             while processed < batch_size:
                 txn.execute(
@@ -1242,7 +1278,11 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
         return row_count
 
 
-class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
+class RoomMemberStore(
+    RoomMemberWorkerStore,
+    RoomMemberBackgroundUpdateStore,
+    CacheInvalidationWorkerStore,
+):
     def __init__(
         self,
         database: DatabasePool,
@@ -1254,7 +1294,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
     async def forget(self, user_id: str, room_id: str) -> None:
         """Indicate that user_id wishes to discard history for room_id."""
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> None:
             sql = (
                 "UPDATE"
                 "  room_memberships"
@@ -1288,5 +1328,5 @@ class _JoinedHostsCache:
     # equal to anything else).
     state_group: Union[object, int] = attr.Factory(object)
 
-    def __len__(self):
+    def __len__(self) -> int:
         return sum(len(v) for v in self.hosts_to_joined_users.values())