summary refs log tree commit diff
path: root/synapse/storage/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/_base.py')
-rw-r--r--synapse/storage/_base.py36
1 files changed, 25 insertions, 11 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 2b196ded1b..a25c4093bc 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -17,14 +17,18 @@
 import logging
 import random
 from abc import ABCMeta
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Iterable, Optional, Union
 
 from synapse.storage.database import LoggingTransaction  # noqa: F401
 from synapse.storage.database import make_in_list_sql_clause  # noqa: F401
 from synapse.storage.database import DatabasePool
-from synapse.types import Collection, get_domain_from_id
+from synapse.storage.types import Connection
+from synapse.types import Collection, StreamToken, get_domain_from_id
 from synapse.util import json_decoder
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -36,24 +40,31 @@ class SQLBaseStore(metaclass=ABCMeta):
     per data store (and not one per physical database).
     """
 
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
         self.hs = hs
         self._clock = hs.get_clock()
         self.database_engine = database.engine
         self.db_pool = database
         self.rand = random.SystemRandom()
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self,
+        stream_name: str,
+        instance_name: str,
+        token: StreamToken,
+        rows: Iterable[Any],
+    ) -> None:
         pass
 
-    def _invalidate_state_caches(self, room_id, members_changed):
+    def _invalidate_state_caches(
+        self, room_id: str, members_changed: Iterable[str]
+    ) -> None:
         """Invalidates caches that are based on the current state, but does
         not stream invalidations down replication.
 
         Args:
-            room_id (str): Room where state changed
-            members_changed (iterable[str]): The user_ids of members that have
-                changed
+            room_id: Room where state changed
+            members_changed: The user_ids of members that have changed
         """
         for host in {get_domain_from_id(u) for u in members_changed}:
             self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
@@ -64,7 +75,7 @@ class SQLBaseStore(metaclass=ABCMeta):
 
     def _attempt_to_invalidate_cache(
         self, cache_name: str, key: Optional[Collection[Any]]
-    ):
+    ) -> None:
         """Attempts to invalidate the cache of the given name, ignoring if the
         cache doesn't exist. Mainly used for invalidating caches on workers,
         where they may not have the cache.
@@ -88,12 +99,15 @@ class SQLBaseStore(metaclass=ABCMeta):
             cache.invalidate(tuple(key))
 
 
-def db_to_json(db_content):
+def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
     """
     Take some data from a database row and return a JSON-decoded object.
 
     Args:
-        db_content (memoryview|buffer|bytes|bytearray|unicode)
+        db_content: The JSON-encoded contents from the database.
+
+    Returns:
+        The object decoded from JSON.
     """
     # psycopg2 on Python 3 returns memoryview objects, which we need to
     # cast to bytes to decode