summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/15035.misc1
-rw-r--r--mypy.ini3
-rw-r--r--synapse/handlers/room.py2
-rw-r--r--synapse/server.py12
-rw-r--r--synapse/storage/_base.py2
-rw-r--r--synapse/storage/database.py1
-rw-r--r--synapse/storage/databases/main/events.py2
7 files changed, 11 insertions, 12 deletions
diff --git a/changelog.d/15035.misc b/changelog.d/15035.misc
new file mode 100644
index 0000000000..93ceaeafc9
--- /dev/null
+++ b/changelog.d/15035.misc
@@ -0,0 +1 @@
+Improve type hints.
diff --git a/mypy.ini b/mypy.ini
index 3f144e61fb..57f27ba4f7 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -53,9 +53,6 @@ warn_unused_ignores = False
 [mypy-synapse.util.caches.treecache]
 disallow_untyped_defs = False
 
-[mypy-synapse.server]
-disallow_untyped_defs = False
-
 [mypy-synapse.storage.database]
 disallow_untyped_defs = False
 
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 7ba7c4ff07..0e759b8a5d 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1076,7 +1076,7 @@ class RoomCreationHandler:
         state_map: MutableStateMap[str] = {}
         # current_state_group of last event created. Used for computing event context of
         # events to be batched
-        current_state_group = None
+        current_state_group: Optional[int] = None
 
         def create_event_dict(etype: str, content: JsonDict, **kwargs: Any) -> JsonDict:
             e = {"type": etype, "content": content}
diff --git a/synapse/server.py b/synapse/server.py
index 9d6d268f49..efc6b5f895 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -21,7 +21,7 @@
 import abc
 import functools
 import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, cast
 
 from twisted.internet.interfaces import IOpenSSLContextFactory
 from twisted.internet.tcp import Port
@@ -144,10 +144,10 @@ if TYPE_CHECKING:
     from synapse.handlers.saml import SamlHandler
 
 
-T = TypeVar("T", bound=Callable[..., Any])
+T = TypeVar("T")
 
 
-def cache_in_self(builder: T) -> T:
+def cache_in_self(builder: Callable[["HomeServer"], T]) -> Callable[["HomeServer"], T]:
     """Wraps a function called e.g. `get_foo`, checking if `self.foo` exists and
     returning if so. If not, calls the given function and sets `self.foo` to it.
 
@@ -166,7 +166,7 @@ def cache_in_self(builder: T) -> T:
     building = [False]
 
     @functools.wraps(builder)
-    def _get(self):
+    def _get(self: "HomeServer") -> T:
         try:
             return getattr(self, depname)
         except AttributeError:
@@ -185,9 +185,7 @@ def cache_in_self(builder: T) -> T:
 
         return dep
 
-    # We cast here as we need to tell mypy that `_get` has the same signature as
-    # `builder`.
-    return cast(T, _get)
+    return _get
 
 
 class HomeServer(metaclass=abc.ABCMeta):
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 41d9111019..481fec72fe 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -37,6 +37,8 @@ class SQLBaseStore(metaclass=ABCMeta):
     per data store (and not one per physical database).
     """
 
+    db_pool: DatabasePool
+
     def __init__(
         self,
         database: DatabasePool,
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index e20c5c5302..feaa6cdd07 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -499,6 +499,7 @@ class DatabasePool:
     """
 
     _TXN_ID = 0
+    engine: BaseDatabaseEngine
 
     def __init__(
         self,
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 1536937b67..cb66376fb4 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -306,7 +306,7 @@ class PersistEventsStore:
 
         # The set of event_ids to return. This includes all soft-failed events
         # and their prev events.
-        existing_prevs = set()
+        existing_prevs: Set[str] = set()
 
         def _get_prevs_before_rejected_txn(
             txn: LoggingTransaction, batch: Collection[str]