summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/10892.misc1
-rw-r--r--mypy.ini6
-rw-r--r--synapse/replication/slave/storage/_slaved_id_tracker.py4
-rw-r--r--synapse/replication/slave/storage/pushers.py10
-rw-r--r--synapse/storage/databases/main/pusher.py10
-rw-r--r--synapse/storage/databases/main/registration.py9
-rw-r--r--synapse/storage/util/id_generators.py143
-rw-r--r--synapse/storage/util/sequence.py6
8 files changed, 124 insertions, 65 deletions
diff --git a/changelog.d/10892.misc b/changelog.d/10892.misc
new file mode 100644
index 0000000000..c8c471159b
--- /dev/null
+++ b/changelog.d/10892.misc
@@ -0,0 +1 @@
+Add further type hints to `synapse.storage.util`.
diff --git a/mypy.ini b/mypy.ini
index e7cb80b6eb..bc2b59ff56 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -105,6 +105,12 @@ disallow_untyped_defs = True
 [mypy-synapse.state.*]
 disallow_untyped_defs = True
 
+[mypy-synapse.storage.util.*]
+disallow_untyped_defs = True
+
+[mypy-synapse.streams.*]
+disallow_untyped_defs = True
+
 [mypy-synapse.util.batching_queue]
 disallow_untyped_defs = True
 
diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py
index 2cb7489047..8c1bf9227a 100644
--- a/synapse/replication/slave/storage/_slaved_id_tracker.py
+++ b/synapse/replication/slave/storage/_slaved_id_tracker.py
@@ -13,14 +13,14 @@
 # limitations under the License.
 from typing import List, Optional, Tuple
 
-from synapse.storage.types import Connection
+from synapse.storage.database import LoggingDatabaseConnection
 from synapse.storage.util.id_generators import _load_current_id
 
 
 class SlavedIdTracker:
     def __init__(
         self,
-        db_conn: Connection,
+        db_conn: LoggingDatabaseConnection,
         table: str,
         column: str,
         extra_tables: Optional[List[Tuple[str, str]]] = None,
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index 2672a2c94b..cea90c0f1b 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -15,9 +15,8 @@
 from typing import TYPE_CHECKING
 
 from synapse.replication.tcp.streams import PushersStream
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.pusher import PusherWorkerStore
-from synapse.storage.types import Connection
 
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
@@ -27,7 +26,12 @@ if TYPE_CHECKING:
 
 
 class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
         self._pushers_id_gen = SlavedIdTracker(  # type: ignore
             db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index a93caae8d0..b73ce53c91 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -18,8 +18,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional,
 
 from synapse.push import PusherConfig, ThrottleParams
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
-from synapse.storage.types import Connection
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.types import JsonDict
 from synapse.util import json_encoder
@@ -32,7 +31,12 @@ logger = logging.getLogger(__name__)
 
 
 class PusherWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
         self._pushers_id_gen = StreamIdGenerator(
             db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 7de4ad7f9b..181841ee06 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -26,7 +26,7 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces
 from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.stats import StatsStore
-from synapse.storage.types import Connection, Cursor
+from synapse.storage.types import Cursor
 from synapse.storage.util.id_generators import IdGenerator
 from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import UserID, UserInfo
@@ -1775,7 +1775,12 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
 
 
 class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self._ignore_unknown_session_error = (
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 6f7cbe40f4..852bd79fee 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -16,42 +16,62 @@ import logging
 import threading
 from collections import OrderedDict
 from contextlib import contextmanager
-from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
+from types import TracebackType
+from typing import (
+    AsyncContextManager,
+    ContextManager,
+    Dict,
+    Generator,
+    Generic,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+)
 
 import attr
 from sortedcontainers import SortedSet
 
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.types import Cursor
 from synapse.storage.util.sequence import PostgresSequenceGenerator
 
 logger = logging.getLogger(__name__)
 
 
+T = TypeVar("T")
+
+
 class IdGenerator:
-    def __init__(self, db_conn, table, column):
+    def __init__(
+        self,
+        db_conn: LoggingDatabaseConnection,
+        table: str,
+        column: str,
+    ):
         self._lock = threading.Lock()
         self._next_id = _load_current_id(db_conn, table, column)
 
-    def get_next(self):
+    def get_next(self) -> int:
         with self._lock:
             self._next_id += 1
             return self._next_id
 
 
-def _load_current_id(db_conn, table, column, step=1):
-    """
-
-    Args:
-        db_conn (object):
-        table (str):
-        column (str):
-        step (int):
-
-    Returns:
-        int
-    """
+def _load_current_id(
+    db_conn: LoggingDatabaseConnection, table: str, column: str, step: int = 1
+) -> int:
     # debug logging for https://github.com/matrix-org/synapse/issues/7968
     logger.info("initialising stream generator for %s(%s)", table, column)
     cur = db_conn.cursor(txn_name="_load_current_id")
@@ -59,7 +79,9 @@ def _load_current_id(db_conn, table, column, step=1):
         cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
     else:
         cur.execute("SELECT MIN(%s) FROM %s" % (column, table))
-    (val,) = cur.fetchone()
+    result = cur.fetchone()
+    assert result is not None
+    (val,) = result
     cur.close()
     current_id = int(val) if val else step
     return (max if step > 0 else min)(current_id, step)
@@ -93,16 +115,16 @@ class StreamIdGenerator:
 
     def __init__(
         self,
-        db_conn,
-        table,
-        column,
+        db_conn: LoggingDatabaseConnection,
+        table: str,
+        column: str,
         extra_tables: Iterable[Tuple[str, str]] = (),
-        step=1,
-    ):
+        step: int = 1,
+    ) -> None:
         assert step != 0
         self._lock = threading.Lock()
-        self._step = step
-        self._current = _load_current_id(db_conn, table, column, step)
+        self._step: int = step
+        self._current: int = _load_current_id(db_conn, table, column, step)
         for table, column in extra_tables:
             self._current = (max if step > 0 else min)(
                 self._current, _load_current_id(db_conn, table, column, step)
@@ -115,7 +137,7 @@ class StreamIdGenerator:
         # The key and values are the same, but we never look at the values.
         self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
 
-    def get_next(self):
+    def get_next(self) -> AsyncContextManager[int]:
         """
         Usage:
             async with stream_id_gen.get_next() as stream_id:
@@ -128,7 +150,7 @@ class StreamIdGenerator:
             self._unfinished_ids[next_id] = next_id
 
         @contextmanager
-        def manager():
+        def manager() -> Generator[int, None, None]:
             try:
                 yield next_id
             finally:
@@ -137,7 +159,7 @@ class StreamIdGenerator:
 
         return _AsyncCtxManagerWrapper(manager())
 
-    def get_next_mult(self, n):
+    def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
         """
         Usage:
             async with stream_id_gen.get_next(n) as stream_ids:
@@ -155,7 +177,7 @@ class StreamIdGenerator:
                 self._unfinished_ids[next_id] = next_id
 
         @contextmanager
-        def manager():
+        def manager() -> Generator[Sequence[int], None, None]:
             try:
                 yield next_ids
             finally:
@@ -215,7 +237,7 @@ class MultiWriterIdGenerator:
 
     def __init__(
         self,
-        db_conn,
+        db_conn: LoggingDatabaseConnection,
         db: DatabasePool,
         stream_name: str,
         instance_name: str,
@@ -223,7 +245,7 @@ class MultiWriterIdGenerator:
         sequence_name: str,
         writers: List[str],
         positive: bool = True,
-    ):
+    ) -> None:
         self._db = db
         self._stream_name = stream_name
         self._instance_name = instance_name
@@ -285,9 +307,9 @@ class MultiWriterIdGenerator:
 
     def _load_current_ids(
         self,
-        db_conn,
+        db_conn: LoggingDatabaseConnection,
         tables: List[Tuple[str, str, str]],
-    ):
+    ) -> None:
         cur = db_conn.cursor(txn_name="_load_current_ids")
 
         # Load the current positions of all writers for the stream.
@@ -335,7 +357,9 @@ class MultiWriterIdGenerator:
                     "agg": "MAX" if self._positive else "-MIN",
                 }
                 cur.execute(sql)
-                (stream_id,) = cur.fetchone()
+                result = cur.fetchone()
+                assert result is not None
+                (stream_id,) = result
 
                 max_stream_id = max(max_stream_id, stream_id)
 
@@ -354,7 +378,7 @@ class MultiWriterIdGenerator:
 
             self._persisted_upto_position = min_stream_id
 
-            rows = []
+            rows: List[Tuple[str, int]] = []
             for table, instance_column, id_column in tables:
                 sql = """
                     SELECT %(instance)s, %(id)s FROM %(table)s
@@ -367,7 +391,8 @@ class MultiWriterIdGenerator:
                 }
                 cur.execute(sql, (min_stream_id * self._return_factor,))
 
-                rows.extend(cur)
+                # Cast safety: this corresponds to the types returned by the query above.
+                rows.extend(cast(Iterable[Tuple[str, int]], cur))
 
             # Sort so that we handle rows in order for each instance.
             rows.sort()
@@ -385,13 +410,13 @@ class MultiWriterIdGenerator:
 
         cur.close()
 
-    def _load_next_id_txn(self, txn) -> int:
+    def _load_next_id_txn(self, txn: Cursor) -> int:
         return self._sequence_gen.get_next_id_txn(txn)
 
-    def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
+    def _load_next_mult_id_txn(self, txn: Cursor, n: int) -> List[int]:
         return self._sequence_gen.get_next_mult_txn(txn, n)
 
-    def get_next(self):
+    def get_next(self) -> AsyncContextManager[int]:
         """
         Usage:
             async with stream_id_gen.get_next() as stream_id:
@@ -403,9 +428,12 @@ class MultiWriterIdGenerator:
         if self._writers and self._instance_name not in self._writers:
             raise Exception("Tried to allocate stream ID on non-writer")
 
-        return _MultiWriterCtxManager(self)
+        # Cast safety: the second argument to _MultiWriterCtxManager, multiple_ids,
+        # controls the return type. If `None` or omitted, the context manager yields
+        # a single integer stream_id; otherwise it yields a list of stream_ids.
+        return cast(AsyncContextManager[int], _MultiWriterCtxManager(self))
 
-    def get_next_mult(self, n: int):
+    def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
         """
         Usage:
             async with stream_id_gen.get_next_mult(5) as stream_ids:
@@ -417,9 +445,10 @@ class MultiWriterIdGenerator:
         if self._writers and self._instance_name not in self._writers:
             raise Exception("Tried to allocate stream ID on non-writer")
 
-        return _MultiWriterCtxManager(self, n)
+        # Cast safety: see get_next.
+        return cast(AsyncContextManager[List[int]], _MultiWriterCtxManager(self, n))
 
-    def get_next_txn(self, txn: LoggingTransaction):
+    def get_next_txn(self, txn: LoggingTransaction) -> int:
         """
         Usage:
 
@@ -457,7 +486,7 @@ class MultiWriterIdGenerator:
 
         return self._return_factor * next_id
 
-    def _mark_id_as_finished(self, next_id: int):
+    def _mark_id_as_finished(self, next_id: int) -> None:
         """The ID has finished being processed so we should advance the
         current position if possible.
         """
@@ -534,7 +563,7 @@ class MultiWriterIdGenerator:
                 for name, i in self._current_positions.items()
             }
 
-    def advance(self, instance_name: str, new_id: int):
+    def advance(self, instance_name: str, new_id: int) -> None:
         """Advance the position of the named writer to the given ID, if greater
         than existing entry.
         """
@@ -560,7 +589,7 @@ class MultiWriterIdGenerator:
         with self._lock:
             return self._return_factor * self._persisted_upto_position
 
-    def _add_persisted_position(self, new_id: int):
+    def _add_persisted_position(self, new_id: int) -> None:
         """Record that we have persisted a position.
 
         This is used to keep the `_current_positions` up to date.
@@ -606,7 +635,7 @@ class MultiWriterIdGenerator:
                 # do.
                 break
 
-    def _update_stream_positions_table_txn(self, txn: Cursor):
+    def _update_stream_positions_table_txn(self, txn: Cursor) -> None:
         """Update the `stream_positions` table with newly persisted position."""
 
         if not self._writers:
@@ -628,20 +657,25 @@ class MultiWriterIdGenerator:
         txn.execute(sql, (self._stream_name, self._instance_name, pos))
 
 
-@attr.s(slots=True)
-class _AsyncCtxManagerWrapper:
+@attr.s(frozen=True, auto_attribs=True)
+class _AsyncCtxManagerWrapper(Generic[T]):
     """Helper class to convert a plain context manager to an async one.
 
     This is mainly useful if you have a plain context manager but the interface
     requires an async one.
     """
 
-    inner = attr.ib()
+    inner: ContextManager[T]
 
-    async def __aenter__(self):
+    async def __aenter__(self) -> T:
         return self.inner.__enter__()
 
-    async def __aexit__(self, exc_type, exc, tb):
+    async def __aexit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc: Optional[BaseException],
+        tb: Optional[TracebackType],
+    ) -> Optional[bool]:
         return self.inner.__exit__(exc_type, exc, tb)
 
 
@@ -671,7 +705,12 @@ class _MultiWriterCtxManager:
         else:
             return [i * self.id_gen._return_factor for i in self.stream_ids]
 
-    async def __aexit__(self, exc_type, exc, tb):
+    async def __aexit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc: Optional[BaseException],
+        tb: Optional[TracebackType],
+    ) -> bool:
         for i in self.stream_ids:
             self.id_gen._mark_id_as_finished(i)
 
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index bb33e04fb1..75268cbe15 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -81,7 +81,7 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
         id_column: str,
         stream_name: Optional[str] = None,
         positive: bool = True,
-    ):
+    ) -> None:
         """Should be called during start up to test that the current value of
         the sequence is greater than or equal to the maximum ID in the table.
 
@@ -122,7 +122,7 @@ class PostgresSequenceGenerator(SequenceGenerator):
         id_column: str,
         stream_name: Optional[str] = None,
         positive: bool = True,
-    ):
+    ) -> None:
         """See SequenceGenerator.check_consistency for docstring."""
 
         txn = db_conn.cursor(txn_name="sequence.check_consistency")
@@ -244,7 +244,7 @@ class LocalSequenceGenerator(SequenceGenerator):
         id_column: str,
         stream_name: Optional[str] = None,
         positive: bool = True,
-    ):
+    ) -> None:
         # There is nothing to do for in memory sequences
         pass