diff --git a/changelog.d/12716.misc b/changelog.d/12716.misc
new file mode 100644
index 0000000000..b07e1b52ee
--- /dev/null
+++ b/changelog.d/12716.misc
@@ -0,0 +1 @@
+Add type annotations to increase the number of modules passing `disallow-untyped-defs`.
\ No newline at end of file
diff --git a/mypy.ini b/mypy.ini
index ba0de419f5..8478dd9e51 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -119,9 +119,18 @@ disallow_untyped_defs = True
[mypy-synapse.federation.transport.client]
disallow_untyped_defs = False
+[mypy-synapse.groups.*]
+disallow_untyped_defs = True
+
[mypy-synapse.handlers.*]
disallow_untyped_defs = True
+[mypy-synapse.http.federation.*]
+disallow_untyped_defs = True
+
+[mypy-synapse.http.request_metrics]
+disallow_untyped_defs = True
+
[mypy-synapse.http.server]
disallow_untyped_defs = True
@@ -196,12 +205,27 @@ disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.state_deltas]
disallow_untyped_defs = True
+[mypy-synapse.storage.databases.main.stream]
+disallow_untyped_defs = True
+
[mypy-synapse.storage.databases.main.transactions]
disallow_untyped_defs = True
[mypy-synapse.storage.databases.main.user_erasure_store]
disallow_untyped_defs = True
+[mypy-synapse.storage.prepare_database]
+disallow_untyped_defs = True
+
+[mypy-synapse.storage.persist_events]
+disallow_untyped_defs = True
+
+[mypy-synapse.storage.state]
+disallow_untyped_defs = True
+
+[mypy-synapse.storage.types]
+disallow_untyped_defs = True
+
[mypy-synapse.storage.util.*]
disallow_untyped_defs = True
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
index 4c3a5a6e24..dfd24af695 100644
--- a/synapse/groups/groups_server.py
+++ b/synapse/groups/groups_server.py
@@ -934,7 +934,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
# Before deleting the group lets kick everyone out of it
users = await self.store.get_users_in_group(group_id, include_private=True)
- async def _kick_user_from_group(user_id):
+ async def _kick_user_from_group(user_id: str) -> None:
if self.hs.is_mine_id(user_id):
groups_local = self.hs.get_groups_local_handler()
assert isinstance(
diff --git a/synapse/http/client.py b/synapse/http/client.py
index b2c9a7c670..084d0a5b84 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -43,8 +43,10 @@ from twisted.internet import defer, error as twisted_error, protocol, ssl
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.interfaces import (
IAddress,
+ IDelayedCall,
IHostResolution,
IReactorPluggableNameResolver,
+ IReactorTime,
IResolutionReceiver,
ITCPTransport,
)
@@ -121,13 +123,15 @@ def check_against_blacklist(
_EPSILON = 0.00000001
-def _make_scheduler(reactor):
+def _make_scheduler(
+ reactor: IReactorTime,
+) -> Callable[[Callable[[], object]], IDelayedCall]:
"""Makes a schedular suitable for a Cooperator using the given reactor.
(This is effectively just a copy from `twisted.internet.task`)
"""
- def _scheduler(x):
+ def _scheduler(x: Callable[[], object]) -> IDelayedCall:
return reactor.callLater(_EPSILON, x)
return _scheduler
@@ -775,7 +779,7 @@ class SimpleHttpClient:
)
-def _timeout_to_request_timed_out_error(f: Failure):
+def _timeout_to_request_timed_out_error(f: Failure) -> Failure:
if f.check(twisted_error.TimeoutError, twisted_error.ConnectingCancelledError):
# The TCP connection has its own timeout (set by the 'connectTimeout' param
# on the Agent), which raises twisted_error.TimeoutError exception.
@@ -809,7 +813,7 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
def __init__(self, deferred: defer.Deferred):
self.deferred = deferred
- def _maybe_fail(self):
+ def _maybe_fail(self) -> None:
"""
Report a max size exceed error and disconnect the first time this is called.
"""
@@ -933,12 +937,12 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory):
Do not use this since it allows an attacker to intercept your communications.
"""
- def __init__(self):
+ def __init__(self) -> None:
self._context = SSL.Context(SSL.SSLv23_METHOD)
self._context.set_verify(VERIFY_NONE, lambda *_: False)
def getContext(self, hostname=None, port=None):
return self._context
- def creatorForNetloc(self, hostname, port):
+ def creatorForNetloc(self, hostname: bytes, port: int):
return self
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index a8a520f809..2f0177f1e2 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -239,7 +239,7 @@ class MatrixHostnameEndpointFactory:
self._srv_resolver = srv_resolver
- def endpointForURI(self, parsed_uri: URI):
+ def endpointForURI(self, parsed_uri: URI) -> "MatrixHostnameEndpoint":
return MatrixHostnameEndpoint(
self._reactor,
self._proxy_reactor,
diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
index f68646fd0d..de0e882b33 100644
--- a/synapse/http/federation/srv_resolver.py
+++ b/synapse/http/federation/srv_resolver.py
@@ -16,7 +16,7 @@
import logging
import random
import time
-from typing import Callable, Dict, List
+from typing import Any, Callable, Dict, List
import attr
@@ -109,7 +109,7 @@ class SrvResolver:
def __init__(
self,
- dns_client=client,
+ dns_client: Any = client,
cache: Dict[bytes, List[Server]] = SERVER_CACHE,
get_time: Callable[[], float] = time.time,
):
diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index 43f2140429..71b685fade 100644
--- a/synapse/http/federation/well_known_resolver.py
+++ b/synapse/http/federation/well_known_resolver.py
@@ -74,9 +74,9 @@ _well_known_cache: TTLCache[bytes, Optional[bytes]] = TTLCache("well-known")
_had_valid_well_known_cache: TTLCache[bytes, bool] = TTLCache("had-valid-well-known")
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class WellKnownLookupResult:
- delegated_server = attr.ib()
+ delegated_server: Optional[bytes]
class WellKnownResolver:
@@ -336,4 +336,4 @@ def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]:
class _FetchWellKnownFailure(Exception):
# True if we didn't get a non-5xx HTTP response, i.e. this may or may not be
# a temporary failure.
- temporary = attr.ib()
+ temporary: bool = attr.ib()
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index c2ec3caa0e..725b5c33b8 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -23,6 +23,8 @@ from http import HTTPStatus
from io import BytesIO, StringIO
from typing import (
TYPE_CHECKING,
+ Any,
+ BinaryIO,
Callable,
Dict,
Generic,
@@ -44,7 +46,7 @@ from typing_extensions import Literal
from twisted.internet import defer
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import IReactorTime
-from twisted.internet.task import _EPSILON, Cooperator
+from twisted.internet.task import Cooperator
from twisted.web.client import ResponseFailed
from twisted.web.http_headers import Headers
from twisted.web.iweb import IBodyProducer, IResponse
@@ -58,11 +60,13 @@ from synapse.api.errors import (
RequestSendFailed,
SynapseError,
)
+from synapse.crypto.context_factory import FederationPolicyForHTTPS
from synapse.http import QuieterFileBodyProducer
from synapse.http.client import (
BlacklistingAgentWrapper,
BodyExceededMaxSize,
ByteWriteable,
+ _make_scheduler,
encode_query_args,
read_body_with_max_size,
)
@@ -181,7 +185,7 @@ class JsonParser(ByteParser[Union[JsonDict, list]]):
CONTENT_TYPE = "application/json"
- def __init__(self):
+ def __init__(self) -> None:
self._buffer = StringIO()
self._binary_wrapper = BinaryIOWrapper(self._buffer)
@@ -299,7 +303,9 @@ async def _handle_response(
class BinaryIOWrapper:
"""A wrapper for a TextIO which converts from bytes on the fly."""
- def __init__(self, file: typing.TextIO, encoding="utf-8", errors="strict"):
+ def __init__(
+ self, file: typing.TextIO, encoding: str = "utf-8", errors: str = "strict"
+ ):
self.decoder = codecs.getincrementaldecoder(encoding)(errors)
self.file = file
@@ -317,7 +323,11 @@ class MatrixFederationHttpClient:
requests.
"""
- def __init__(self, hs: "HomeServer", tls_client_options_factory):
+ def __init__(
+ self,
+ hs: "HomeServer",
+ tls_client_options_factory: Optional[FederationPolicyForHTTPS],
+ ):
self.hs = hs
self.signing_key = hs.signing_key
self.server_name = hs.hostname
@@ -348,10 +358,7 @@ class MatrixFederationHttpClient:
self.version_string_bytes = hs.version_string.encode("ascii")
self.default_timeout = 60
- def schedule(x):
- self.reactor.callLater(_EPSILON, x)
-
- self._cooperator = Cooperator(scheduler=schedule)
+ self._cooperator = Cooperator(scheduler=_make_scheduler(self.reactor))
self._sleeper = AwakenableSleeper(self.reactor)
@@ -364,7 +371,7 @@ class MatrixFederationHttpClient:
self,
request: MatrixFederationRequest,
try_trailing_slash_on_400: bool = False,
- **send_request_args,
+ **send_request_args: Any,
) -> IResponse:
"""Wrapper for _send_request which can optionally retry the request
upon receiving a combination of a 400 HTTP response code and a
@@ -1159,7 +1166,7 @@ class MatrixFederationHttpClient:
self,
destination: str,
path: str,
- output_stream,
+ output_stream: BinaryIO,
args: Optional[QueryParams] = None,
retry_on_dns_fail: bool = True,
max_size: Optional[int] = None,
@@ -1250,10 +1257,10 @@ class MatrixFederationHttpClient:
return length, headers
-def _flatten_response_never_received(e):
+def _flatten_response_never_received(e: BaseException) -> str:
if hasattr(e, "reasons"):
reasons = ", ".join(
- _flatten_response_never_received(f.value) for f in e.reasons
+ _flatten_response_never_received(f.value) for f in e.reasons # type: ignore[attr-defined]
)
return "%s:[%s]" % (type(e).__name__, reasons)
diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py
index 4886626d50..2b6d113544 100644
--- a/synapse/http/request_metrics.py
+++ b/synapse/http/request_metrics.py
@@ -162,7 +162,7 @@ class RequestMetrics:
with _in_flight_requests_lock:
_in_flight_requests.add(self)
- def stop(self, time_sec, response_code, sent_bytes):
+ def stop(self, time_sec: float, response_code: int, sent_bytes: int) -> None:
with _in_flight_requests_lock:
_in_flight_requests.discard(self)
@@ -186,13 +186,13 @@ class RequestMetrics:
)
return
- response_code = str(response_code)
+ response_code_str = str(response_code)
- outgoing_responses_counter.labels(self.method, response_code).inc()
+ outgoing_responses_counter.labels(self.method, response_code_str).inc()
response_count.labels(self.method, self.name, tag).inc()
- response_timer.labels(self.method, self.name, tag, response_code).observe(
+ response_timer.labels(self.method, self.name, tag, response_code_str).observe(
time_sec - self.start_ts
)
@@ -221,7 +221,7 @@ class RequestMetrics:
# flight.
self.update_metrics()
- def update_metrics(self):
+ def update_metrics(self) -> None:
"""Updates the in flight metrics with values from this request."""
if not self.start_context:
logger.error(
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 41f566b648..5ddb58a8a2 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -31,6 +31,7 @@ from typing import (
List,
Optional,
Tuple,
+ Type,
TypeVar,
cast,
overload,
@@ -41,6 +42,7 @@ from prometheus_client import Histogram
from typing_extensions import Concatenate, Literal, ParamSpec
from twisted.enterprise import adbapi
+from twisted.internet.interfaces import IReactorCore
from synapse.api.errors import StoreError
from synapse.config.database import DatabaseConnectionConfig
@@ -92,7 +94,9 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
def make_pool(
- reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+ reactor: IReactorCore,
+ db_config: DatabaseConnectionConfig,
+ engine: BaseDatabaseEngine,
) -> adbapi.ConnectionPool:
"""Get the connection pool for the database."""
@@ -101,7 +105,7 @@ def make_pool(
db_args = dict(db_config.config.get("args", {}))
db_args.setdefault("cp_reconnect", True)
- def _on_new_connection(conn):
+ def _on_new_connection(conn: Connection) -> None:
# Ensure we have a logging context so we can correctly track queries,
# etc.
with LoggingContext("db.on_new_connection"):
@@ -157,7 +161,11 @@ class LoggingDatabaseConnection:
default_txn_name: str
def cursor(
- self, *, txn_name=None, after_callbacks=None, exception_callbacks=None
+ self,
+ *,
+ txn_name: Optional[str] = None,
+ after_callbacks: Optional[List["_CallbackListEntry"]] = None,
+ exception_callbacks: Optional[List["_CallbackListEntry"]] = None,
) -> "LoggingTransaction":
if not txn_name:
txn_name = self.default_txn_name
@@ -183,11 +191,16 @@ class LoggingDatabaseConnection:
self.conn.__enter__()
return self
- def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_value: Optional[BaseException],
+ traceback: Optional[types.TracebackType],
+ ) -> Optional[bool]:
return self.conn.__exit__(exc_type, exc_value, traceback)
# Proxy through any unknown lookups to the DB conn class.
- def __getattr__(self, name):
+ def __getattr__(self, name: str) -> Any:
return getattr(self.conn, name)
@@ -391,17 +404,22 @@ class LoggingTransaction:
def __enter__(self) -> "LoggingTransaction":
return self
- def __exit__(self, exc_type, exc_value, traceback):
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_value: Optional[BaseException],
+ traceback: Optional[types.TracebackType],
+ ) -> None:
self.close()
class PerformanceCounters:
- def __init__(self):
- self.current_counters = {}
- self.previous_counters = {}
+ def __init__(self) -> None:
+ self.current_counters: Dict[str, Tuple[int, float]] = {}
+ self.previous_counters: Dict[str, Tuple[int, float]] = {}
def update(self, key: str, duration_secs: float) -> None:
- count, cum_time = self.current_counters.get(key, (0, 0))
+ count, cum_time = self.current_counters.get(key, (0, 0.0))
count += 1
cum_time += duration_secs
self.current_counters[key] = (count, cum_time)
@@ -527,7 +545,7 @@ class DatabasePool:
def start_profiling(self) -> None:
self._previous_loop_ts = monotonic_time()
- def loop():
+ def loop() -> None:
curr = self._current_txn_total_time
prev = self._previous_txn_total_time
self._previous_txn_total_time = curr
@@ -1186,7 +1204,7 @@ class DatabasePool:
if lock:
self.engine.lock_table(txn, table)
- def _getwhere(key):
+ def _getwhere(key: str) -> str:
# If the value we're passing in is None (aka NULL), we need to use
# IS, not =, as NULL = NULL equals NULL (False).
if keyvalues[key] is None:
@@ -2258,7 +2276,7 @@ class DatabasePool:
term: Optional[str],
col: str,
retcols: Collection[str],
- desc="simple_search_list",
+ desc: str = "simple_search_list",
) -> Optional[List[Dict[str, Any]]]:
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index 1480a0f048..d03555a585 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -23,6 +23,7 @@ from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore,
)
+from synapse.storage.types import Cursor
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -71,7 +72,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
self._last_user_visit_update = self._get_start_of_day()
@wrap_as_background_process("read_forward_extremities")
- async def _read_forward_extremities(self):
+ async def _read_forward_extremities(self) -> None:
def fetch(txn):
txn.execute(
"""
@@ -95,7 +96,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
(x[0] - 1) * x[1] for x in res if x[1]
)
- async def count_daily_e2ee_messages(self):
+ async def count_daily_e2ee_messages(self) -> int:
"""
Returns an estimate of the number of messages sent in the last day.
@@ -115,7 +116,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages)
- async def count_daily_sent_e2ee_messages(self):
+ async def count_daily_sent_e2ee_messages(self) -> int:
def _count_messages(txn):
# This is good enough as if you have silly characters in your own
# hostname then that's your own fault.
@@ -136,7 +137,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"count_daily_sent_e2ee_messages", _count_messages
)
- async def count_daily_active_e2ee_rooms(self):
+ async def count_daily_active_e2ee_rooms(self) -> int:
def _count(txn):
sql = """
SELECT COUNT(DISTINCT room_id) FROM events
@@ -151,7 +152,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"count_daily_active_e2ee_rooms", _count
)
- async def count_daily_messages(self):
+ async def count_daily_messages(self) -> int:
"""
Returns an estimate of the number of messages sent in the last day.
@@ -171,7 +172,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
return await self.db_pool.runInteraction("count_messages", _count_messages)
- async def count_daily_sent_messages(self):
+ async def count_daily_sent_messages(self) -> int:
def _count_messages(txn):
# This is good enough as if you have silly characters in your own
# hostname then that's your own fault.
@@ -192,7 +193,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"count_daily_sent_messages", _count_messages
)
- async def count_daily_active_rooms(self):
+ async def count_daily_active_rooms(self) -> int:
def _count(txn):
sql = """
SELECT COUNT(DISTINCT room_id) FROM events
@@ -226,7 +227,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"count_monthly_users", self._count_users, thirty_days_ago
)
- def _count_users(self, txn, time_from):
+ def _count_users(self, txn: Cursor, time_from: int) -> int:
"""
Returns number of users seen in the past time_from period
"""
@@ -238,7 +239,10 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
) u
"""
txn.execute(sql, (time_from,))
- (count,) = txn.fetchone()
+ # Mypy knows that fetchone() might return None if there are no rows.
+ # We know better: "SELECT COUNT(...) FROM ..." without any GROUP BY always
+ # returns exactly one row.
+ (count,) = txn.fetchone() # type: ignore[misc]
return count
async def count_r30_users(self) -> Dict[str, int]:
@@ -453,7 +457,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"count_r30v2_users", _count_r30v2_users
)
- def _get_start_of_day(self):
+ def _get_start_of_day(self) -> int:
"""
Returns millisecond unixtime for start of UTC day.
"""
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 4e1d9647b7..59bbca2e32 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -798,9 +798,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
self,
txn: LoggingTransaction,
event_id: str,
- allow_none=False,
- ) -> int:
- return self.db_pool.simple_select_one_onecol_txn(
+ allow_none: bool = False,
+ ) -> Optional[int]:
+ # Type ignore: we pass keyvalues a Dict[str, str]; the function wants
+ # Dict[str, Any]. I think mypy is unhappy because Dict is invariant?
+ return self.db_pool.simple_select_one_onecol_txn( # type: ignore[call-overload]
txn=txn,
table="events",
keyvalues={"event_id": event_id},
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index a7f6338e05..0fc282866b 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -25,6 +25,7 @@ from typing import (
Collection,
Deque,
Dict,
+ Generator,
Generic,
Iterable,
List,
@@ -207,7 +208,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
return res
- def _handle_queue(self, room_id):
+ def _handle_queue(self, room_id: str) -> None:
"""Attempts to handle the queue for a room if not already being handled.
The queue's callback will be invoked with for each item in the queue,
@@ -227,7 +228,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
self._currently_persisting_rooms.add(room_id)
- async def handle_queue_loop():
+ async def handle_queue_loop() -> None:
try:
queue = self._get_drainining_queue(room_id)
for item in queue:
@@ -250,15 +251,17 @@ class _EventPeristenceQueue(Generic[_PersistResult]):
with PreserveLoggingContext():
item.deferred.callback(ret)
finally:
- queue = self._event_persist_queues.pop(room_id, None)
- if queue:
- self._event_persist_queues[room_id] = queue
+ remaining_queue = self._event_persist_queues.pop(room_id, None)
+ if remaining_queue:
+ self._event_persist_queues[room_id] = remaining_queue
self._currently_persisting_rooms.discard(room_id)
# set handle_queue_loop off in the background
run_as_background_process("persist_events", handle_queue_loop)
- def _get_drainining_queue(self, room_id):
+ def _get_drainining_queue(
+ self, room_id: str
+ ) -> Generator[_EventPersistQueueItem, None, None]:
queue = self._event_persist_queues.setdefault(room_id, deque())
try:
@@ -317,7 +320,9 @@ class EventsPersistenceStorage:
for event, ctx in events_and_contexts:
partitioned.setdefault(event.room_id, []).append((event, ctx))
- async def enqueue(item):
+ async def enqueue(
+ item: Tuple[str, List[Tuple[EventBase, EventContext]]]
+ ) -> Dict[str, str]:
room_id, evs_ctxs = item
return await self._event_persist_queue.add_to_queue(
room_id, evs_ctxs, backfilled=backfilled
@@ -1102,7 +1107,7 @@ class EventsPersistenceStorage:
return False
- async def _handle_potentially_left_users(self, user_ids: Set[str]):
+ async def _handle_potentially_left_users(self, user_ids: Set[str]) -> None:
"""Given a set of remote users check if the server still shares a room with
them. If not then mark those users' device cache as stale.
"""
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 546d6bae6e..c33df42084 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -85,7 +85,7 @@ def prepare_database(
database_engine: BaseDatabaseEngine,
config: Optional[HomeServerConfig],
databases: Collection[str] = ("main", "state"),
-):
+) -> None:
"""Prepares a physical database for usage. Will either create all necessary tables
or upgrade from an older schema version.
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index d1d5859214..d4a1bd4f9d 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -62,7 +62,7 @@ class StateFilter:
types: "frozendict[str, Optional[FrozenSet[str]]]"
include_others: bool = False
- def __attrs_post_init__(self):
+ def __attrs_post_init__(self) -> None:
# If `include_others` is set we canonicalise the filter by removing
# wildcards from the types dictionary
if self.include_others:
@@ -138,7 +138,9 @@ class StateFilter:
)
@staticmethod
- def freeze(types: Mapping[str, Optional[Collection[str]]], include_others: bool):
+ def freeze(
+ types: Mapping[str, Optional[Collection[str]]], include_others: bool
+ ) -> "StateFilter":
"""
Returns a (frozen) StateFilter with the same contents as the parameters
specified here, which can be made of mutable types.
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
index d7d6f1d90e..40536c1830 100644
--- a/synapse/storage/types.py
+++ b/synapse/storage/types.py
@@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Union
+from types import TracebackType
+from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union
from typing_extensions import Protocol
@@ -86,5 +87,10 @@ class Connection(Protocol):
def __enter__(self) -> "Connection":
...
- def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
+ def __exit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc_value: Optional[BaseException],
+ traceback: Optional[TracebackType],
+ ) -> Optional[bool]:
...
|