diff options
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/groups/groups_server.py | 2 | ||||
-rw-r--r-- | synapse/http/client.py | 16 | ||||
-rw-r--r-- | synapse/http/federation/matrix_federation_agent.py | 2 | ||||
-rw-r--r-- | synapse/http/federation/srv_resolver.py | 4 | ||||
-rw-r--r-- | synapse/http/federation/well_known_resolver.py | 6 | ||||
-rw-r--r-- | synapse/http/matrixfederationclient.py | 31 | ||||
-rw-r--r-- | synapse/http/request_metrics.py | 10 | ||||
-rw-r--r-- | synapse/storage/database.py | 44 | ||||
-rw-r--r-- | synapse/storage/databases/main/metrics.py | 24 | ||||
-rw-r--r-- | synapse/storage/databases/main/stream.py | 8 | ||||
-rw-r--r-- | synapse/storage/persist_events.py | 21 | ||||
-rw-r--r-- | synapse/storage/prepare_database.py | 2 | ||||
-rw-r--r-- | synapse/storage/state.py | 6 | ||||
-rw-r--r-- | synapse/storage/types.py | 10 |
14 files changed, 117 insertions, 69 deletions
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 4c3a5a6e24..dfd24af695 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -934,7 +934,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler): # Before deleting the group lets kick everyone out of it users = await self.store.get_users_in_group(group_id, include_private=True) - async def _kick_user_from_group(user_id): + async def _kick_user_from_group(user_id: str) -> None: if self.hs.is_mine_id(user_id): groups_local = self.hs.get_groups_local_handler() assert isinstance( diff --git a/synapse/http/client.py b/synapse/http/client.py index b2c9a7c670..084d0a5b84 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -43,8 +43,10 @@ from twisted.internet import defer, error as twisted_error, protocol, ssl from twisted.internet.address import IPv4Address, IPv6Address from twisted.internet.interfaces import ( IAddress, + IDelayedCall, IHostResolution, IReactorPluggableNameResolver, + IReactorTime, IResolutionReceiver, ITCPTransport, ) @@ -121,13 +123,15 @@ def check_against_blacklist( _EPSILON = 0.00000001 -def _make_scheduler(reactor): +def _make_scheduler( + reactor: IReactorTime, +) -> Callable[[Callable[[], object]], IDelayedCall]: """Makes a schedular suitable for a Cooperator using the given reactor. (This is effectively just a copy from `twisted.internet.task`) """ - def _scheduler(x): + def _scheduler(x: Callable[[], object]) -> IDelayedCall: return reactor.callLater(_EPSILON, x) return _scheduler @@ -775,7 +779,7 @@ class SimpleHttpClient: ) -def _timeout_to_request_timed_out_error(f: Failure): +def _timeout_to_request_timed_out_error(f: Failure) -> Failure: if f.check(twisted_error.TimeoutError, twisted_error.ConnectingCancelledError): # The TCP connection has its own timeout (set by the 'connectTimeout' param # on the Agent), which raises twisted_error.TimeoutError exception. @@ -809,7 +813,7 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol): def __init__(self, deferred: defer.Deferred): self.deferred = deferred - def _maybe_fail(self): + def _maybe_fail(self) -> None: """ Report a max size exceed error and disconnect the first time this is called. """ @@ -933,12 +937,12 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory): Do not use this since it allows an attacker to intercept your communications. """ - def __init__(self): + def __init__(self) -> None: self._context = SSL.Context(SSL.SSLv23_METHOD) self._context.set_verify(VERIFY_NONE, lambda *_: False) def getContext(self, hostname=None, port=None): return self._context - def creatorForNetloc(self, hostname, port): + def creatorForNetloc(self, hostname: bytes, port: int): return self diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index a8a520f809..2f0177f1e2 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -239,7 +239,7 @@ class MatrixHostnameEndpointFactory: self._srv_resolver = srv_resolver - def endpointForURI(self, parsed_uri: URI): + def endpointForURI(self, parsed_uri: URI) -> "MatrixHostnameEndpoint": return MatrixHostnameEndpoint( self._reactor, self._proxy_reactor, diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py index f68646fd0d..de0e882b33 100644 --- a/synapse/http/federation/srv_resolver.py +++ b/synapse/http/federation/srv_resolver.py @@ -16,7 +16,7 @@ import logging import random import time -from typing import Callable, Dict, List +from typing import Any, Callable, Dict, List import attr @@ -109,7 +109,7 @@ class SrvResolver: def __init__( self, - dns_client=client, + dns_client: Any = client, cache: Dict[bytes, List[Server]] = SERVER_CACHE, get_time: Callable[[], float] = time.time, ): diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py index 43f2140429..71b685fade 100644 --- a/synapse/http/federation/well_known_resolver.py +++ b/synapse/http/federation/well_known_resolver.py @@ -74,9 +74,9 @@ _well_known_cache: TTLCache[bytes, Optional[bytes]] = TTLCache("well-known") _had_valid_well_known_cache: TTLCache[bytes, bool] = TTLCache("had-valid-well-known") -@attr.s(slots=True, frozen=True) +@attr.s(slots=True, frozen=True, auto_attribs=True) class WellKnownLookupResult: - delegated_server = attr.ib() + delegated_server: Optional[bytes] class WellKnownResolver: @@ -336,4 +336,4 @@ def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]: class _FetchWellKnownFailure(Exception): # True if we didn't get a non-5xx HTTP response, i.e. this may or may not be # a temporary failure. - temporary = attr.ib() + temporary: bool = attr.ib() diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index c2ec3caa0e..725b5c33b8 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -23,6 +23,8 @@ from http import HTTPStatus from io import BytesIO, StringIO from typing import ( TYPE_CHECKING, + Any, + BinaryIO, Callable, Dict, Generic, @@ -44,7 +46,7 @@ from typing_extensions import Literal from twisted.internet import defer from twisted.internet.error import DNSLookupError from twisted.internet.interfaces import IReactorTime -from twisted.internet.task import _EPSILON, Cooperator +from twisted.internet.task import Cooperator from twisted.web.client import ResponseFailed from twisted.web.http_headers import Headers from twisted.web.iweb import IBodyProducer, IResponse @@ -58,11 +60,13 @@ from synapse.api.errors import ( RequestSendFailed, SynapseError, ) +from synapse.crypto.context_factory import FederationPolicyForHTTPS from synapse.http import QuieterFileBodyProducer from synapse.http.client import ( BlacklistingAgentWrapper, BodyExceededMaxSize, ByteWriteable, + _make_scheduler, encode_query_args, read_body_with_max_size, ) @@ -181,7 +185,7 @@ class JsonParser(ByteParser[Union[JsonDict, list]]): CONTENT_TYPE = "application/json" - def __init__(self): + def __init__(self) -> None: self._buffer = StringIO() self._binary_wrapper = BinaryIOWrapper(self._buffer) @@ -299,7 +303,9 @@ async def _handle_response( class BinaryIOWrapper: """A wrapper for a TextIO which converts from bytes on the fly.""" - def __init__(self, file: typing.TextIO, encoding="utf-8", errors="strict"): + def __init__( + self, file: typing.TextIO, encoding: str = "utf-8", errors: str = "strict" + ): self.decoder = codecs.getincrementaldecoder(encoding)(errors) self.file = file @@ -317,7 +323,11 @@ class MatrixFederationHttpClient: requests. """ - def __init__(self, hs: "HomeServer", tls_client_options_factory): + def __init__( + self, + hs: "HomeServer", + tls_client_options_factory: Optional[FederationPolicyForHTTPS], + ): self.hs = hs self.signing_key = hs.signing_key self.server_name = hs.hostname @@ -348,10 +358,7 @@ class MatrixFederationHttpClient: self.version_string_bytes = hs.version_string.encode("ascii") self.default_timeout = 60 - def schedule(x): - self.reactor.callLater(_EPSILON, x) - - self._cooperator = Cooperator(scheduler=schedule) + self._cooperator = Cooperator(scheduler=_make_scheduler(self.reactor)) self._sleeper = AwakenableSleeper(self.reactor) @@ -364,7 +371,7 @@ class MatrixFederationHttpClient: self, request: MatrixFederationRequest, try_trailing_slash_on_400: bool = False, - **send_request_args, + **send_request_args: Any, ) -> IResponse: """Wrapper for _send_request which can optionally retry the request upon receiving a combination of a 400 HTTP response code and a @@ -1159,7 +1166,7 @@ class MatrixFederationHttpClient: self, destination: str, path: str, - output_stream, + output_stream: BinaryIO, args: Optional[QueryParams] = None, retry_on_dns_fail: bool = True, max_size: Optional[int] = None, @@ -1250,10 +1257,10 @@ class MatrixFederationHttpClient: return length, headers -def _flatten_response_never_received(e): +def _flatten_response_never_received(e: BaseException) -> str: if hasattr(e, "reasons"): reasons = ", ".join( - _flatten_response_never_received(f.value) for f in e.reasons + _flatten_response_never_received(f.value) for f in e.reasons # type: ignore[attr-defined] ) return "%s:[%s]" % (type(e).__name__, reasons) diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py index 4886626d50..2b6d113544 100644 --- a/synapse/http/request_metrics.py +++ b/synapse/http/request_metrics.py @@ -162,7 +162,7 @@ class RequestMetrics: with _in_flight_requests_lock: _in_flight_requests.add(self) - def stop(self, time_sec, response_code, sent_bytes): + def stop(self, time_sec: float, response_code: int, sent_bytes: int) -> None: with _in_flight_requests_lock: _in_flight_requests.discard(self) @@ -186,13 +186,13 @@ class RequestMetrics: ) return - response_code = str(response_code) + response_code_str = str(response_code) - outgoing_responses_counter.labels(self.method, response_code).inc() + outgoing_responses_counter.labels(self.method, response_code_str).inc() response_count.labels(self.method, self.name, tag).inc() - response_timer.labels(self.method, self.name, tag, response_code).observe( + response_timer.labels(self.method, self.name, tag, response_code_str).observe( time_sec - self.start_ts ) @@ -221,7 +221,7 @@ class RequestMetrics: # flight. self.update_metrics() - def update_metrics(self): + def update_metrics(self) -> None: """Updates the in flight metrics with values from this request.""" if not self.start_context: logger.error( diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 41f566b648..5ddb58a8a2 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -31,6 +31,7 @@ from typing import ( List, Optional, Tuple, + Type, TypeVar, cast, overload, @@ -41,6 +42,7 @@ from prometheus_client import Histogram from typing_extensions import Concatenate, Literal, ParamSpec from twisted.enterprise import adbapi +from twisted.internet.interfaces import IReactorCore from synapse.api.errors import StoreError from synapse.config.database import DatabaseConnectionConfig @@ -92,7 +94,9 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = { def make_pool( - reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine + reactor: IReactorCore, + db_config: DatabaseConnectionConfig, + engine: BaseDatabaseEngine, ) -> adbapi.ConnectionPool: """Get the connection pool for the database.""" @@ -101,7 +105,7 @@ def make_pool( db_args = dict(db_config.config.get("args", {})) db_args.setdefault("cp_reconnect", True) - def _on_new_connection(conn): + def _on_new_connection(conn: Connection) -> None: # Ensure we have a logging context so we can correctly track queries, # etc. with LoggingContext("db.on_new_connection"): @@ -157,7 +161,11 @@ class LoggingDatabaseConnection: default_txn_name: str def cursor( - self, *, txn_name=None, after_callbacks=None, exception_callbacks=None + self, + *, + txn_name: Optional[str] = None, + after_callbacks: Optional[List["_CallbackListEntry"]] = None, + exception_callbacks: Optional[List["_CallbackListEntry"]] = None, ) -> "LoggingTransaction": if not txn_name: txn_name = self.default_txn_name @@ -183,11 +191,16 @@ class LoggingDatabaseConnection: self.conn.__enter__() return self - def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]: + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[types.TracebackType], + ) -> Optional[bool]: return self.conn.__exit__(exc_type, exc_value, traceback) # Proxy through any unknown lookups to the DB conn class. - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: return getattr(self.conn, name) @@ -391,17 +404,22 @@ class LoggingTransaction: def __enter__(self) -> "LoggingTransaction": return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[types.TracebackType], + ) -> None: self.close() class PerformanceCounters: - def __init__(self): - self.current_counters = {} - self.previous_counters = {} + def __init__(self) -> None: + self.current_counters: Dict[str, Tuple[int, float]] = {} + self.previous_counters: Dict[str, Tuple[int, float]] = {} def update(self, key: str, duration_secs: float) -> None: - count, cum_time = self.current_counters.get(key, (0, 0)) + count, cum_time = self.current_counters.get(key, (0, 0.0)) count += 1 cum_time += duration_secs self.current_counters[key] = (count, cum_time) @@ -527,7 +545,7 @@ class DatabasePool: def start_profiling(self) -> None: self._previous_loop_ts = monotonic_time() - def loop(): + def loop() -> None: curr = self._current_txn_total_time prev = self._previous_txn_total_time self._previous_txn_total_time = curr @@ -1186,7 +1204,7 @@ class DatabasePool: if lock: self.engine.lock_table(txn, table) - def _getwhere(key): + def _getwhere(key: str) -> str: # If the value we're passing in is None (aka NULL), we need to use # IS, not =, as NULL = NULL equals NULL (False). if keyvalues[key] is None: @@ -2258,7 +2276,7 @@ class DatabasePool: term: Optional[str], col: str, retcols: Collection[str], - desc="simple_search_list", + desc: str = "simple_search_list", ) -> Optional[List[Dict[str, Any]]]: """Executes a SELECT query on the named table, which may return zero or more rows, returning the result as a list of dicts. diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py index 1480a0f048..d03555a585 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py @@ -23,6 +23,7 @@ from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.event_push_actions import ( EventPushActionsWorkerStore, ) +from synapse.storage.types import Cursor if TYPE_CHECKING: from synapse.server import HomeServer @@ -71,7 +72,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): self._last_user_visit_update = self._get_start_of_day() @wrap_as_background_process("read_forward_extremities") - async def _read_forward_extremities(self): + async def _read_forward_extremities(self) -> None: def fetch(txn): txn.execute( """ @@ -95,7 +96,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): (x[0] - 1) * x[1] for x in res if x[1] ) - async def count_daily_e2ee_messages(self): + async def count_daily_e2ee_messages(self) -> int: """ Returns an estimate of the number of messages sent in the last day. @@ -115,7 +116,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): return await self.db_pool.runInteraction("count_e2ee_messages", _count_messages) - async def count_daily_sent_e2ee_messages(self): + async def count_daily_sent_e2ee_messages(self) -> int: def _count_messages(txn): # This is good enough as if you have silly characters in your own # hostname then that's your own fault. @@ -136,7 +137,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): "count_daily_sent_e2ee_messages", _count_messages ) - async def count_daily_active_e2ee_rooms(self): + async def count_daily_active_e2ee_rooms(self) -> int: def _count(txn): sql = """ SELECT COUNT(DISTINCT room_id) FROM events @@ -151,7 +152,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): "count_daily_active_e2ee_rooms", _count ) - async def count_daily_messages(self): + async def count_daily_messages(self) -> int: """ Returns an estimate of the number of messages sent in the last day. @@ -171,7 +172,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): return await self.db_pool.runInteraction("count_messages", _count_messages) - async def count_daily_sent_messages(self): + async def count_daily_sent_messages(self) -> int: def _count_messages(txn): # This is good enough as if you have silly characters in your own # hostname then that's your own fault. @@ -192,7 +193,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): "count_daily_sent_messages", _count_messages ) - async def count_daily_active_rooms(self): + async def count_daily_active_rooms(self) -> int: def _count(txn): sql = """ SELECT COUNT(DISTINCT room_id) FROM events @@ -226,7 +227,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): "count_monthly_users", self._count_users, thirty_days_ago ) - def _count_users(self, txn, time_from): + def _count_users(self, txn: Cursor, time_from: int) -> int: """ Returns number of users seen in the past time_from period """ @@ -238,7 +239,10 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): ) u """ txn.execute(sql, (time_from,)) - (count,) = txn.fetchone() + # Mypy knows that fetchone() might return None if there are no rows. + # We know better: "SELECT COUNT(...) FROM ..." without any GROUP BY always + # returns exactly one row. + (count,) = txn.fetchone() # type: ignore[misc] return count async def count_r30_users(self) -> Dict[str, int]: @@ -453,7 +457,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): "count_r30v2_users", _count_r30v2_users ) - def _get_start_of_day(self): + def _get_start_of_day(self) -> int: """ Returns millisecond unixtime for start of UTC day. """ diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 4e1d9647b7..59bbca2e32 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -798,9 +798,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): self, txn: LoggingTransaction, event_id: str, - allow_none=False, - ) -> int: - return self.db_pool.simple_select_one_onecol_txn( + allow_none: bool = False, + ) -> Optional[int]: + # Type ignore: we pass keyvalues a Dict[str, str]; the function wants + # Dict[str, Any]. I think mypy is unhappy because Dict is invariant? + return self.db_pool.simple_select_one_onecol_txn( # type: ignore[call-overload] txn=txn, table="events", keyvalues={"event_id": event_id}, diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index a7f6338e05..0fc282866b 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -25,6 +25,7 @@ from typing import ( Collection, Deque, Dict, + Generator, Generic, Iterable, List, @@ -207,7 +208,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]): return res - def _handle_queue(self, room_id): + def _handle_queue(self, room_id: str) -> None: """Attempts to handle the queue for a room if not already being handled. The queue's callback will be invoked with for each item in the queue, @@ -227,7 +228,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]): self._currently_persisting_rooms.add(room_id) - async def handle_queue_loop(): + async def handle_queue_loop() -> None: try: queue = self._get_drainining_queue(room_id) for item in queue: @@ -250,15 +251,17 @@ class _EventPeristenceQueue(Generic[_PersistResult]): with PreserveLoggingContext(): item.deferred.callback(ret) finally: - queue = self._event_persist_queues.pop(room_id, None) - if queue: - self._event_persist_queues[room_id] = queue + remaining_queue = self._event_persist_queues.pop(room_id, None) + if remaining_queue: + self._event_persist_queues[room_id] = remaining_queue self._currently_persisting_rooms.discard(room_id) # set handle_queue_loop off in the background run_as_background_process("persist_events", handle_queue_loop) - def _get_drainining_queue(self, room_id): + def _get_drainining_queue( + self, room_id: str + ) -> Generator[_EventPersistQueueItem, None, None]: queue = self._event_persist_queues.setdefault(room_id, deque()) try: @@ -317,7 +320,9 @@ class EventsPersistenceStorage: for event, ctx in events_and_contexts: partitioned.setdefault(event.room_id, []).append((event, ctx)) - async def enqueue(item): + async def enqueue( + item: Tuple[str, List[Tuple[EventBase, EventContext]]] + ) -> Dict[str, str]: room_id, evs_ctxs = item return await self._event_persist_queue.add_to_queue( room_id, evs_ctxs, backfilled=backfilled @@ -1102,7 +1107,7 @@ class EventsPersistenceStorage: return False - async def _handle_potentially_left_users(self, user_ids: Set[str]): + async def _handle_potentially_left_users(self, user_ids: Set[str]) -> None: """Given a set of remote users check if the server still shares a room with them. If not then mark those users' device cache as stale. """ diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 546d6bae6e..c33df42084 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -85,7 +85,7 @@ def prepare_database( database_engine: BaseDatabaseEngine, config: Optional[HomeServerConfig], databases: Collection[str] = ("main", "state"), -): +) -> None: """Prepares a physical database for usage. Will either create all necessary tables or upgrade from an older schema version. diff --git a/synapse/storage/state.py b/synapse/storage/state.py index d1d5859214..d4a1bd4f9d 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -62,7 +62,7 @@ class StateFilter: types: "frozendict[str, Optional[FrozenSet[str]]]" include_others: bool = False - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: # If `include_others` is set we canonicalise the filter by removing # wildcards from the types dictionary if self.include_others: @@ -138,7 +138,9 @@ class StateFilter: ) @staticmethod - def freeze(types: Mapping[str, Optional[Collection[str]]], include_others: bool): + def freeze( + types: Mapping[str, Optional[Collection[str]]], include_others: bool + ) -> "StateFilter": """ Returns a (frozen) StateFilter with the same contents as the parameters specified here, which can be made of mutable types. diff --git a/synapse/storage/types.py b/synapse/storage/types.py index d7d6f1d90e..40536c1830 100644 --- a/synapse/storage/types.py +++ b/synapse/storage/types.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Union +from types import TracebackType +from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union from typing_extensions import Protocol @@ -86,5 +87,10 @@ class Connection(Protocol): def __enter__(self) -> "Connection": ... - def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]: + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> Optional[bool]: ... |