summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorDavid Robertson <davidr@element.io>2022-05-13 12:35:31 +0100
committerGitHub <noreply@github.com>2022-05-13 12:35:31 +0100
commitaec69d2481e9ea1d8ea1c0ffce1706a65a7896a8 (patch)
tree1017693baaa1bd87632b3a252395bb5a8503d1c3 /synapse
parentSpamChecker metrics (#12513) (diff)
downloadsynapse-aec69d2481e9ea1d8ea1c0ffce1706a65a7896a8.tar.xz
Another batch of type annotations (#12726)
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/e2e_keys.py29
-rw-r--r--synapse/http/connectproxyclient.py39
-rw-r--r--synapse/http/proxyagent.py2
-rw-r--r--synapse/logging/_remote.py20
-rw-r--r--synapse/logging/formatter.py14
-rw-r--r--synapse/logging/handlers.py4
-rw-r--r--synapse/logging/scopecontextmanager.py28
-rw-r--r--synapse/storage/background_updates.py19
-rw-r--r--synapse/types.py46
9 files changed, 122 insertions, 79 deletions
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index d6714228ef..e6c2cfb8c8 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
 
 import attr
 from canonicaljson import encode_canonical_json
@@ -1105,22 +1105,19 @@ class E2eKeysHandler:
             # can request over federation
             raise NotFoundError("No %s key found for %s" % (key_type, user_id))
 
-        (
-            key,
-            key_id,
-            verify_key,
-        ) = await self._retrieve_cross_signing_keys_for_remote_user(user, key_type)
-
-        if key is None:
+        cross_signing_keys = await self._retrieve_cross_signing_keys_for_remote_user(
+            user, key_type
+        )
+        if cross_signing_keys is None:
             raise NotFoundError("No %s key found for %s" % (key_type, user_id))
 
-        return key, key_id, verify_key
+        return cross_signing_keys
 
     async def _retrieve_cross_signing_keys_for_remote_user(
         self,
         user: UserID,
         desired_key_type: str,
-    ) -> Tuple[Optional[dict], Optional[str], Optional[VerifyKey]]:
+    ) -> Optional[Tuple[Dict[str, Any], str, VerifyKey]]:
         """Queries cross-signing keys for a remote user and saves them to the database
 
         Only the key specified by `key_type` will be returned, while all retrieved keys
@@ -1146,12 +1143,10 @@ class E2eKeysHandler:
                 type(e),
                 e,
             )
-            return None, None, None
+            return None
 
         # Process each of the retrieved cross-signing keys
-        desired_key = None
-        desired_key_id = None
-        desired_verify_key = None
+        desired_key_data = None
         retrieved_device_ids = []
         for key_type in ["master", "self_signing"]:
             key_content = remote_result.get(key_type + "_key")
@@ -1196,9 +1191,7 @@ class E2eKeysHandler:
 
             # If this is the desired key type, save it and its ID/VerifyKey
             if key_type == desired_key_type:
-                desired_key = key_content
-                desired_verify_key = verify_key
-                desired_key_id = key_id
+                desired_key_data = key_content, key_id, verify_key
 
             # At the same time, store this key in the db for subsequent queries
             await self.store.set_e2e_cross_signing_key(
@@ -1212,7 +1205,7 @@ class E2eKeysHandler:
                 user.to_string(), retrieved_device_ids
             )
 
-        return desired_key, desired_key_id, desired_verify_key
+        return desired_key_data
 
 
 def _check_cross_signing_key(
diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py
index 203e995bb7..23a60af171 100644
--- a/synapse/http/connectproxyclient.py
+++ b/synapse/http/connectproxyclient.py
@@ -14,15 +14,22 @@
 
 import base64
 import logging
-from typing import Optional
+from typing import Optional, Union
 
 import attr
 from zope.interface import implementer
 
 from twisted.internet import defer, protocol
 from twisted.internet.error import ConnectError
-from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint
+from twisted.internet.interfaces import (
+    IAddress,
+    IConnector,
+    IProtocol,
+    IReactorCore,
+    IStreamClientEndpoint,
+)
 from twisted.internet.protocol import ClientFactory, Protocol, connectionDone
+from twisted.python.failure import Failure
 from twisted.web import http
 
 logger = logging.getLogger(__name__)
@@ -81,14 +88,14 @@ class HTTPConnectProxyEndpoint:
         self._port = port
         self._proxy_creds = proxy_creds
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,)
 
     # Mypy encounters a false positive here: it complains that ClientFactory
     # is incompatible with IProtocolFactory. But ClientFactory inherits from
     # Factory, which implements IProtocolFactory. So I think this is a bug
     # in mypy-zope.
-    def connect(self, protocolFactory: ClientFactory):  # type: ignore[override]
+    def connect(self, protocolFactory: ClientFactory) -> "defer.Deferred[IProtocol]":  # type: ignore[override]
         f = HTTPProxiedClientFactory(
             self._host, self._port, protocolFactory, self._proxy_creds
         )
@@ -125,10 +132,10 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
         self.proxy_creds = proxy_creds
         self.on_connection: "defer.Deferred[None]" = defer.Deferred()
 
-    def startedConnecting(self, connector):
+    def startedConnecting(self, connector: IConnector) -> None:
         return self.wrapped_factory.startedConnecting(connector)
 
-    def buildProtocol(self, addr):
+    def buildProtocol(self, addr: IAddress) -> "HTTPConnectProtocol":
         wrapped_protocol = self.wrapped_factory.buildProtocol(addr)
         if wrapped_protocol is None:
             raise TypeError("buildProtocol produced None instead of a Protocol")
@@ -141,13 +148,13 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
             self.proxy_creds,
         )
 
-    def clientConnectionFailed(self, connector, reason):
+    def clientConnectionFailed(self, connector: IConnector, reason: Failure) -> None:
         logger.debug("Connection to proxy failed: %s", reason)
         if not self.on_connection.called:
             self.on_connection.errback(reason)
         return self.wrapped_factory.clientConnectionFailed(connector, reason)
 
-    def clientConnectionLost(self, connector, reason):
+    def clientConnectionLost(self, connector: IConnector, reason: Failure) -> None:
         logger.debug("Connection to proxy lost: %s", reason)
         if not self.on_connection.called:
             self.on_connection.errback(reason)
@@ -191,10 +198,10 @@ class HTTPConnectProtocol(protocol.Protocol):
         )
         self.http_setup_client.on_connected.addCallback(self.proxyConnected)
 
-    def connectionMade(self):
+    def connectionMade(self) -> None:
         self.http_setup_client.makeConnection(self.transport)
 
-    def connectionLost(self, reason=connectionDone):
+    def connectionLost(self, reason: Failure = connectionDone) -> None:
         if self.wrapped_protocol.connected:
             self.wrapped_protocol.connectionLost(reason)
 
@@ -203,7 +210,7 @@ class HTTPConnectProtocol(protocol.Protocol):
         if not self.connected_deferred.called:
             self.connected_deferred.errback(reason)
 
-    def proxyConnected(self, _):
+    def proxyConnected(self, _: Union[None, "defer.Deferred[None]"]) -> None:
         self.wrapped_protocol.makeConnection(self.transport)
 
         self.connected_deferred.callback(self.wrapped_protocol)
@@ -213,7 +220,7 @@ class HTTPConnectProtocol(protocol.Protocol):
         if buf:
             self.wrapped_protocol.dataReceived(buf)
 
-    def dataReceived(self, data: bytes):
+    def dataReceived(self, data: bytes) -> None:
         # if we've set up the HTTP protocol, we can send the data there
         if self.wrapped_protocol.connected:
             return self.wrapped_protocol.dataReceived(data)
@@ -243,7 +250,7 @@ class HTTPConnectSetupClient(http.HTTPClient):
         self.proxy_creds = proxy_creds
         self.on_connected: "defer.Deferred[None]" = defer.Deferred()
 
-    def connectionMade(self):
+    def connectionMade(self) -> None:
         logger.debug("Connected to proxy, sending CONNECT")
         self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port))
 
@@ -257,14 +264,14 @@ class HTTPConnectSetupClient(http.HTTPClient):
 
         self.endHeaders()
 
-    def handleStatus(self, version: bytes, status: bytes, message: bytes):
+    def handleStatus(self, version: bytes, status: bytes, message: bytes) -> None:
         logger.debug("Got Status: %s %s %s", status, message, version)
         if status != b"200":
             raise ProxyConnectError(f"Unexpected status on CONNECT: {status!s}")
 
-    def handleEndHeaders(self):
+    def handleEndHeaders(self) -> None:
         logger.debug("End Headers")
         self.on_connected.callback(None)
 
-    def handleResponse(self, body):
+    def handleResponse(self, body: bytes) -> None:
         pass
diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index a16dde2380..b2a50c9105 100644
--- a/synapse/http/proxyagent.py
+++ b/synapse/http/proxyagent.py
@@ -245,7 +245,7 @@ def http_proxy_endpoint(
     proxy: Optional[bytes],
     reactor: IReactorCore,
     tls_options_factory: Optional[IPolicyForHTTPS],
-    **kwargs,
+    **kwargs: object,
 ) -> Tuple[Optional[IStreamClientEndpoint], Optional[ProxyCredentials]]:
     """Parses an http proxy setting and returns an endpoint for the proxy
 
diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py
index 475756f1db..5a61b21eaf 100644
--- a/synapse/logging/_remote.py
+++ b/synapse/logging/_remote.py
@@ -31,7 +31,11 @@ from twisted.internet.endpoints import (
     TCP4ClientEndpoint,
     TCP6ClientEndpoint,
 )
-from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint
+from twisted.internet.interfaces import (
+    IPushProducer,
+    IReactorTCP,
+    IStreamClientEndpoint,
+)
 from twisted.internet.protocol import Factory, Protocol
 from twisted.internet.tcp import Connection
 from twisted.python.failure import Failure
@@ -59,14 +63,14 @@ class LogProducer:
     _buffer: Deque[logging.LogRecord]
     _paused: bool = attr.ib(default=False, init=False)
 
-    def pauseProducing(self):
+    def pauseProducing(self) -> None:
         self._paused = True
 
-    def stopProducing(self):
+    def stopProducing(self) -> None:
         self._paused = True
         self._buffer = deque()
 
-    def resumeProducing(self):
+    def resumeProducing(self) -> None:
         # If we're already producing, nothing to do.
         self._paused = False
 
@@ -102,8 +106,8 @@ class RemoteHandler(logging.Handler):
         host: str,
         port: int,
         maximum_buffer: int = 1000,
-        level=logging.NOTSET,
-        _reactor=None,
+        level: int = logging.NOTSET,
+        _reactor: Optional[IReactorTCP] = None,
     ):
         super().__init__(level=level)
         self.host = host
@@ -118,7 +122,7 @@ class RemoteHandler(logging.Handler):
         if _reactor is None:
             from twisted.internet import reactor
 
-            _reactor = reactor
+            _reactor = reactor  # type: ignore[assignment]
 
         try:
             ip = ip_address(self.host)
@@ -139,7 +143,7 @@ class RemoteHandler(logging.Handler):
         self._stopping = False
         self._connect()
 
-    def close(self):
+    def close(self) -> None:
         self._stopping = True
         self._service.stopService()
 
diff --git a/synapse/logging/formatter.py b/synapse/logging/formatter.py
index c0f12ecd15..c88b8ae545 100644
--- a/synapse/logging/formatter.py
+++ b/synapse/logging/formatter.py
@@ -16,6 +16,8 @@
 import logging
 import traceback
 from io import StringIO
+from types import TracebackType
+from typing import Optional, Tuple, Type
 
 
 class LogFormatter(logging.Formatter):
@@ -28,10 +30,14 @@ class LogFormatter(logging.Formatter):
     where it was caught are logged).
     """
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-    def formatException(self, ei):
+    def formatException(
+        self,
+        ei: Tuple[
+            Optional[Type[BaseException]],
+            Optional[BaseException],
+            Optional[TracebackType],
+        ],
+    ) -> str:
         sio = StringIO()
         (typ, val, tb) = ei
 
diff --git a/synapse/logging/handlers.py b/synapse/logging/handlers.py
index 478b527494..dec2a2c3dd 100644
--- a/synapse/logging/handlers.py
+++ b/synapse/logging/handlers.py
@@ -49,7 +49,7 @@ class PeriodicallyFlushingMemoryHandler(MemoryHandler):
         )
         self._flushing_thread.start()
 
-        def on_reactor_running():
+        def on_reactor_running() -> None:
             self._reactor_started = True
 
         reactor_to_use: IReactorCore
@@ -74,7 +74,7 @@ class PeriodicallyFlushingMemoryHandler(MemoryHandler):
         else:
             return True
 
-    def _flush_periodically(self):
+    def _flush_periodically(self) -> None:
         """
         Whilst this handler is active, flush the handler periodically.
         """
diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py
index d57e7c5324..a26a1a58e7 100644
--- a/synapse/logging/scopecontextmanager.py
+++ b/synapse/logging/scopecontextmanager.py
@@ -13,6 +13,8 @@
 # limitations under the License.import logging
 
 import logging
+from types import TracebackType
+from typing import Optional, Type
 
 from opentracing import Scope, ScopeManager
 
@@ -107,19 +109,26 @@ class _LogContextScope(Scope):
         and - if enter_logcontext was set - the logcontext is finished too.
     """
 
-    def __init__(self, manager, span, logcontext, enter_logcontext, finish_on_close):
+    def __init__(
+        self,
+        manager: LogContextScopeManager,
+        span,
+        logcontext,
+        enter_logcontext: bool,
+        finish_on_close: bool,
+    ):
         """
         Args:
-            manager (LogContextScopeManager):
+            manager:
                 the manager that is responsible for this scope.
             span (Span):
                 the opentracing span which this scope represents the local
                 lifetime for.
             logcontext (LogContext):
                 the logcontext to which this scope is attached.
-            enter_logcontext (Boolean):
+            enter_logcontext:
                 if True the logcontext will be exited when the scope is finished
-            finish_on_close (Boolean):
+            finish_on_close:
                 if True finish the span when the scope is closed
         """
         super().__init__(manager, span)
@@ -127,16 +136,21 @@ class _LogContextScope(Scope):
         self._finish_on_close = finish_on_close
         self._enter_logcontext = enter_logcontext
 
-    def __exit__(self, exc_type, value, traceback):
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> None:
         if exc_type == twisted.internet.defer._DefGen_Return:
             # filter out defer.returnValue() calls
             exc_type = value = traceback = None
         super().__exit__(exc_type, value, traceback)
 
-    def __str__(self):
+    def __str__(self) -> str:
         return f"Scope<{self.span}>"
 
-    def close(self):
+    def close(self) -> None:
         active_scope = self.manager.active
         if active_scope is not self:
             logger.error(
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 08c6eabc6d..c2bbbb574e 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -12,20 +12,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from types import TracebackType
 from typing import (
     TYPE_CHECKING,
+    Any,
     AsyncContextManager,
     Awaitable,
     Callable,
     Dict,
     Iterable,
+    List,
     Optional,
+    Type,
 )
 
 import attr
 
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage.types import Connection
+from synapse.storage.types import Connection, Cursor
 from synapse.types import JsonDict
 from synapse.util import Clock, json_encoder
 
@@ -74,7 +78,12 @@ class _BackgroundUpdateContextManager:
 
         return self._update_duration_ms
 
-    async def __aexit__(self, *exc) -> None:
+    async def __aexit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc: Optional[BaseException],
+        tb: Optional[TracebackType],
+    ) -> None:
         pass
 
 
@@ -352,7 +361,7 @@ class BackgroundUpdater:
             True if we have finished running all the background updates, otherwise False
         """
 
-        def get_background_updates_txn(txn):
+        def get_background_updates_txn(txn: Cursor) -> List[Dict[str, Any]]:
             txn.execute(
                 """
                 SELECT update_name, depends_on FROM background_updates
@@ -469,7 +478,7 @@ class BackgroundUpdater:
         self,
         update_name: str,
         update_handler: Callable[[JsonDict, int], Awaitable[int]],
-    ):
+    ) -> None:
         """Register a handler for doing a background update.
 
         The handler should take two arguments:
@@ -603,7 +612,7 @@ class BackgroundUpdater:
         else:
             runner = create_index_sqlite
 
-        async def updater(progress, batch_size):
+        async def updater(progress: JsonDict, batch_size: int) -> int:
             if runner is not None:
                 logger.info("Adding index %s to %s", index_name, table)
                 await self.db_pool.runWithConnection(runner)
diff --git a/synapse/types.py b/synapse/types.py
index 9ac688b23b..325332a6e0 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -24,6 +24,7 @@ from typing import (
     Mapping,
     Match,
     MutableMapping,
+    NoReturn,
     Optional,
     Set,
     Tuple,
@@ -35,6 +36,7 @@ from typing import (
 import attr
 from frozendict import frozendict
 from signedjson.key import decode_verify_key_bytes
+from signedjson.types import VerifyKey
 from typing_extensions import TypedDict
 from unpaddedbase64 import decode_base64
 from zope.interface import Interface
@@ -55,6 +57,7 @@ from synapse.util.stringutils import parse_and_validate_server_name
 if TYPE_CHECKING:
     from synapse.appservice.api import ApplicationService
     from synapse.storage.databases.main import DataStore, PurgeEventsStore
+    from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
 
 # Define a state map type from type/state_key to T (usually an event ID or
 # event)
@@ -114,7 +117,7 @@ class Requester:
     app_service: Optional["ApplicationService"]
     authenticated_entity: str
 
-    def serialize(self):
+    def serialize(self) -> Dict[str, Any]:
         """Converts self to a type that can be serialized as JSON, and then
         deserialized by `deserialize`
 
@@ -132,7 +135,9 @@ class Requester:
         }
 
     @staticmethod
-    def deserialize(store, input):
+    def deserialize(
+        store: "ApplicationServiceWorkerStore", input: Dict[str, Any]
+    ) -> "Requester":
         """Converts a dict that was produced by `serialize` back into a
         Requester.
 
@@ -236,10 +241,10 @@ class DomainSpecificString(metaclass=abc.ABCMeta):
     domain: str
 
     # Because this is a frozen class, it is deeply immutable.
-    def __copy__(self):
+    def __copy__(self: DS) -> DS:
         return self
 
-    def __deepcopy__(self, memo):
+    def __deepcopy__(self: DS, memo: Dict[str, object]) -> DS:
         return self
 
     @classmethod
@@ -729,12 +734,14 @@ class StreamToken:
         )
 
     @property
-    def room_stream_id(self):
+    def room_stream_id(self) -> int:
         return self.room_key.stream
 
-    def copy_and_advance(self, key, new_value) -> "StreamToken":
+    def copy_and_advance(self, key: str, new_value: Any) -> "StreamToken":
         """Advance the given key in the token to a new value if and only if the
         new value is after the old value.
+
+        :raises TypeError: if `key` is not the one of the keys tracked by a StreamToken.
         """
         if key == "room_key":
             new_token = self.copy_and_replace(
@@ -751,7 +758,7 @@ class StreamToken:
         else:
             return self
 
-    def copy_and_replace(self, key, new_value) -> "StreamToken":
+    def copy_and_replace(self, key: str, new_value: Any) -> "StreamToken":
         return attr.evolve(self, **{key: new_value})
 
 
@@ -793,14 +800,14 @@ class ThirdPartyInstanceID:
     # Deny iteration because it will bite you if you try to create a singleton
     # set by:
     #    users = set(user)
-    def __iter__(self):
+    def __iter__(self) -> NoReturn:
         raise ValueError("Attempted to iterate a %s" % (type(self).__name__,))
 
     # Because this class is a frozen class, it is deeply immutable.
-    def __copy__(self):
+    def __copy__(self) -> "ThirdPartyInstanceID":
         return self
 
-    def __deepcopy__(self, memo):
+    def __deepcopy__(self, memo: Dict[str, object]) -> "ThirdPartyInstanceID":
         return self
 
     @classmethod
@@ -852,25 +859,28 @@ class DeviceListUpdates:
         return bool(self.changed or self.left)
 
 
-def get_verify_key_from_cross_signing_key(key_info):
+def get_verify_key_from_cross_signing_key(
+    key_info: Mapping[str, Any]
+) -> Tuple[str, VerifyKey]:
     """Get the key ID and signedjson verify key from a cross-signing key dict
 
     Args:
-        key_info (dict): a cross-signing key dict, which must have a "keys"
+        key_info: a cross-signing key dict, which must have a "keys"
             property that has exactly one item in it
 
     Returns:
-        (str, VerifyKey): the key ID and verify key for the cross-signing key
+        the key ID and verify key for the cross-signing key
     """
-    # make sure that exactly one key is provided
+    # make sure that a `keys` field is provided
     if "keys" not in key_info:
         raise ValueError("Invalid key")
     keys = key_info["keys"]
-    if len(keys) != 1:
-        raise ValueError("Invalid key")
-    # and return that one key
-    for key_id, key_data in keys.items():
+    # and that it contains exactly one key
+    if len(keys) == 1:
+        key_id, key_data = next(iter(keys.items()))
         return key_id, decode_verify_key_bytes(key_id, decode_base64(key_data))
+    else:
+        raise ValueError("Invalid key")
 
 
 @attr.s(auto_attribs=True, frozen=True, slots=True)