summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/11953.misc1
-rw-r--r--mypy.ini6
-rw-r--r--synapse/event_auth.py4
-rw-r--r--synapse/handlers/oidc.py4
-rw-r--r--synapse/http/client.py11
-rw-r--r--synapse/http/matrixfederationclient.py3
-rw-r--r--synapse/notifier.py43
-rw-r--r--synapse/server.py8
-rw-r--r--tests/handlers/test_oidc.py9
9 files changed, 48 insertions, 41 deletions
diff --git a/changelog.d/11953.misc b/changelog.d/11953.misc
new file mode 100644
index 0000000000..d44571b731
--- /dev/null
+++ b/changelog.d/11953.misc
@@ -0,0 +1 @@
+Add missing type hints.
diff --git a/mypy.ini b/mypy.ini
index cd28ac0dd2..63848d664c 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -142,6 +142,9 @@ disallow_untyped_defs = True
 [mypy-synapse.crypto.*]
 disallow_untyped_defs = True
 
+[mypy-synapse.event_auth]
+disallow_untyped_defs = True
+
 [mypy-synapse.events.*]
 disallow_untyped_defs = True
 
@@ -166,6 +169,9 @@ disallow_untyped_defs = True
 [mypy-synapse.module_api.*]
 disallow_untyped_defs = True
 
+[mypy-synapse.notifier]
+disallow_untyped_defs = True
+
 [mypy-synapse.push.*]
 disallow_untyped_defs = True
 
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index e885961698..19b55a9559 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -763,7 +763,9 @@ def get_named_level(auth_events: StateMap[EventBase], name: str, default: int) -
         return default
 
 
-def _verify_third_party_invite(event: EventBase, auth_events: StateMap[EventBase]):
+def _verify_third_party_invite(
+    event: EventBase, auth_events: StateMap[EventBase]
+) -> bool:
     """
     Validates that the invite event is authorized by a previous third-party invite.
 
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index deb3539751..8f71d975e9 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -544,9 +544,9 @@ class OidcProvider:
         """
         metadata = await self.load_metadata()
         token_endpoint = metadata.get("token_endpoint")
-        raw_headers = {
+        raw_headers: Dict[str, str] = {
             "Content-Type": "application/x-www-form-urlencoded",
-            "User-Agent": self._http_client.user_agent,
+            "User-Agent": self._http_client.user_agent.decode("ascii"),
             "Accept": "application/json",
         }
 
diff --git a/synapse/http/client.py b/synapse/http/client.py
index d617055617..c01d2326cf 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -322,21 +322,20 @@ class SimpleHttpClient:
         self._ip_whitelist = ip_whitelist
         self._ip_blacklist = ip_blacklist
         self._extra_treq_args = treq_args or {}
-
-        self.user_agent = hs.version_string
         self.clock = hs.get_clock()
+
+        user_agent = hs.version_string
         if hs.config.server.user_agent_suffix:
-            self.user_agent = "%s %s" % (
-                self.user_agent,
+            user_agent = "%s %s" % (
+                user_agent,
                 hs.config.server.user_agent_suffix,
             )
+        self.user_agent = user_agent.encode("ascii")
 
         # We use this for our body producers to ensure that they use the correct
         # reactor.
         self._cooperator = Cooperator(scheduler=_make_scheduler(hs.get_reactor()))
 
-        self.user_agent = self.user_agent.encode("ascii")
-
         if self._ip_blacklist:
             # If we have an IP blacklist, we need to use a DNS resolver which
             # filters out blacklisted IP addresses, to prevent DNS rebinding.
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 2e668363b2..c5f8fcbb2a 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -334,12 +334,11 @@ class MatrixFederationHttpClient:
         user_agent = hs.version_string
         if hs.config.server.user_agent_suffix:
             user_agent = "%s %s" % (user_agent, hs.config.server.user_agent_suffix)
-        user_agent = user_agent.encode("ascii")
 
         federation_agent = MatrixFederationAgent(
             self.reactor,
             tls_client_options_factory,
-            user_agent,
+            user_agent.encode("ascii"),
             hs.config.server.federation_ip_range_whitelist,
             hs.config.server.federation_ip_range_blacklist,
         )
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 5988c67d90..e0fad2da66 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -14,6 +14,7 @@
 
 import logging
 from typing import (
+    TYPE_CHECKING,
     Awaitable,
     Callable,
     Collection,
@@ -32,7 +33,6 @@ from prometheus_client import Counter
 
 from twisted.internet import defer
 
-import synapse.server
 from synapse.api.constants import EventTypes, HistoryVisibility, Membership
 from synapse.api.errors import AuthError
 from synapse.events import EventBase
@@ -53,6 +53,9 @@ from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
 from synapse.util.metrics import Measure
 from synapse.visibility import filter_events_for_client
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 notified_events_counter = Counter("synapse_notifier_notified_events", "")
@@ -82,7 +85,7 @@ class _NotificationListener:
 
     __slots__ = ["deferred"]
 
-    def __init__(self, deferred):
+    def __init__(self, deferred: "defer.Deferred"):
         self.deferred = deferred
 
 
@@ -124,7 +127,7 @@ class _NotifierUserStream:
         stream_key: str,
         stream_id: Union[int, RoomStreamToken],
         time_now_ms: int,
-    ):
+    ) -> None:
         """Notify any listeners for this user of a new event from an
         event source.
         Args:
@@ -152,7 +155,7 @@ class _NotifierUserStream:
             self.notify_deferred = ObservableDeferred(defer.Deferred())
             noify_deferred.callback(self.current_token)
 
-    def remove(self, notifier: "Notifier"):
+    def remove(self, notifier: "Notifier") -> None:
         """Remove this listener from all the indexes in the Notifier
         it knows about.
         """
@@ -188,7 +191,7 @@ class EventStreamResult:
     start_token: StreamToken
     end_token: StreamToken
 
-    def __bool__(self):
+    def __bool__(self) -> bool:
         return bool(self.events)
 
 
@@ -212,7 +215,7 @@ class Notifier:
 
     UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000
 
-    def __init__(self, hs: "synapse.server.HomeServer"):
+    def __init__(self, hs: "HomeServer"):
         self.user_to_user_stream: Dict[str, _NotifierUserStream] = {}
         self.room_to_user_streams: Dict[str, Set[_NotifierUserStream]] = {}
 
@@ -248,7 +251,7 @@ class Notifier:
         # This is not a very cheap test to perform, but it's only executed
         # when rendering the metrics page, which is likely once per minute at
         # most when scraping it.
-        def count_listeners():
+        def count_listeners() -> int:
             all_user_streams: Set[_NotifierUserStream] = set()
 
             for streams in list(self.room_to_user_streams.values()):
@@ -270,7 +273,7 @@ class Notifier:
             "synapse_notifier_users", "", [], lambda: len(self.user_to_user_stream)
         )
 
-    def add_replication_callback(self, cb: Callable[[], None]):
+    def add_replication_callback(self, cb: Callable[[], None]) -> None:
         """Add a callback that will be called when some new data is available.
         Callback is not given any arguments. It should *not* return a Deferred - if
         it needs to do any asynchronous work, a background thread should be started and
@@ -284,7 +287,7 @@ class Notifier:
         event_pos: PersistedEventPosition,
         max_room_stream_token: RoomStreamToken,
         extra_users: Optional[Collection[UserID]] = None,
-    ):
+    ) -> None:
         """Unwraps event and calls `on_new_room_event_args`."""
         await self.on_new_room_event_args(
             event_pos=event_pos,
@@ -307,7 +310,7 @@ class Notifier:
         event_pos: PersistedEventPosition,
         max_room_stream_token: RoomStreamToken,
         extra_users: Optional[Collection[UserID]] = None,
-    ):
+    ) -> None:
         """Used by handlers to inform the notifier something has happened
         in the room, room event wise.
 
@@ -338,7 +341,9 @@ class Notifier:
 
         self.notify_replication()
 
-    def _notify_pending_new_room_events(self, max_room_stream_token: RoomStreamToken):
+    def _notify_pending_new_room_events(
+        self, max_room_stream_token: RoomStreamToken
+    ) -> None:
         """Notify for the room events that were queued waiting for a previous
         event to be persisted.
         Args:
@@ -374,7 +379,7 @@ class Notifier:
             )
             self._on_updated_room_token(max_room_stream_token)
 
-    def _on_updated_room_token(self, max_room_stream_token: RoomStreamToken):
+    def _on_updated_room_token(self, max_room_stream_token: RoomStreamToken) -> None:
         """Poke services that might care that the room position has been
         updated.
         """
@@ -386,13 +391,13 @@ class Notifier:
         if self.federation_sender:
             self.federation_sender.notify_new_events(max_room_stream_token)
 
-    def _notify_app_services(self, max_room_stream_token: RoomStreamToken):
+    def _notify_app_services(self, max_room_stream_token: RoomStreamToken) -> None:
         try:
             self.appservice_handler.notify_interested_services(max_room_stream_token)
         except Exception:
             logger.exception("Error notifying application services of event")
 
-    def _notify_pusher_pool(self, max_room_stream_token: RoomStreamToken):
+    def _notify_pusher_pool(self, max_room_stream_token: RoomStreamToken) -> None:
         try:
             self._pusher_pool.on_new_notifications(max_room_stream_token)
         except Exception:
@@ -475,8 +480,8 @@ class Notifier:
         user_id: str,
         timeout: int,
         callback: Callable[[StreamToken, StreamToken], Awaitable[T]],
-        room_ids=None,
-        from_token=StreamToken.START,
+        room_ids: Optional[Collection[str]] = None,
+        from_token: StreamToken = StreamToken.START,
     ) -> T:
         """Wait until the callback returns a non empty response or the
         timeout fires.
@@ -700,14 +705,14 @@ class Notifier:
         for expired_stream in expired_streams:
             expired_stream.remove(self)
 
-    def _register_with_keys(self, user_stream: _NotifierUserStream):
+    def _register_with_keys(self, user_stream: _NotifierUserStream) -> None:
         self.user_to_user_stream[user_stream.user_id] = user_stream
 
         for room in user_stream.rooms:
             s = self.room_to_user_streams.setdefault(room, set())
             s.add(user_stream)
 
-    def _user_joined_room(self, user_id: str, room_id: str):
+    def _user_joined_room(self, user_id: str, room_id: str) -> None:
         new_user_stream = self.user_to_user_stream.get(user_id)
         if new_user_stream is not None:
             room_streams = self.room_to_user_streams.setdefault(room_id, set())
@@ -719,7 +724,7 @@ class Notifier:
         for cb in self.replication_callbacks:
             cb()
 
-    def notify_remote_server_up(self, server: str):
+    def notify_remote_server_up(self, server: str) -> None:
         """Notify any replication that a remote server has come back up"""
         # We call federation_sender directly rather than registering as a
         # callback as a) we already have a reference to it and b) it introduces
diff --git a/synapse/server.py b/synapse/server.py
index 3032f0b738..564afdcb96 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -233,8 +233,8 @@ class HomeServer(metaclass=abc.ABCMeta):
         self,
         hostname: str,
         config: HomeServerConfig,
-        reactor=None,
-        version_string="Synapse",
+        reactor: Optional[ISynapseReactor] = None,
+        version_string: str = "Synapse",
     ):
         """
         Args:
@@ -244,7 +244,7 @@ class HomeServer(metaclass=abc.ABCMeta):
         if not reactor:
             from twisted.internet import reactor as _reactor
 
-            reactor = _reactor
+            reactor = cast(ISynapseReactor, _reactor)
 
         self._reactor = reactor
         self.hostname = hostname
@@ -264,7 +264,7 @@ class HomeServer(metaclass=abc.ABCMeta):
         self._module_web_resources: Dict[str, Resource] = {}
         self._module_web_resources_consumed = False
 
-    def register_module_web_resource(self, path: str, resource: Resource):
+    def register_module_web_resource(self, path: str, resource: Resource) -> None:
         """Allows a module to register a web resource to be served at the given path.
 
         If multiple modules register a resource for the same path, the module that
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index cfe3de5266..a552d8182e 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -155,7 +155,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
         self.http_client = Mock(spec=["get_json"])
         self.http_client.get_json.side_effect = get_json
-        self.http_client.user_agent = "Synapse Test"
+        self.http_client.user_agent = b"Synapse Test"
 
         hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
 
@@ -438,12 +438,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
         state = "state"
         nonce = "nonce"
         client_redirect_url = "http://client/redirect"
-        user_agent = "Browser"
         ip_address = "10.0.0.1"
         session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
-        request = _build_callback_request(
-            code, state, session, user_agent=user_agent, ip_address=ip_address
-        )
+        request = _build_callback_request(code, state, session, ip_address=ip_address)
 
         self.get_success(self.handler.handle_oidc_callback(request))
 
@@ -1274,7 +1271,6 @@ def _build_callback_request(
     code: str,
     state: str,
     session: str,
-    user_agent: str = "Browser",
     ip_address: str = "10.0.0.1",
 ):
     """Builds a fake SynapseRequest to mock the browser callback
@@ -1289,7 +1285,6 @@ def _build_callback_request(
            query param. Should be the same as was embedded in the session in
            _build_oidc_session.
         session: the "session" which would have been passed around in the cookie.
-        user_agent: the user-agent to present
         ip_address: the IP address to pretend the request came from
     """
     request = Mock(