summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/8655.misc1
-rw-r--r--mypy.ini4
-rw-r--r--synapse/handlers/appservice.py75
-rw-r--r--synapse/handlers/auth.py23
-rw-r--r--synapse/storage/databases/main/appservice.py98
5 files changed, 122 insertions, 79 deletions
diff --git a/changelog.d/8655.misc b/changelog.d/8655.misc
new file mode 100644
index 0000000000..b588bdd3e2
--- /dev/null
+++ b/changelog.d/8655.misc
@@ -0,0 +1 @@
+Add more type hints to the application services code.
diff --git a/mypy.ini b/mypy.ini
index 1fbd8decf8..1ece2ba082 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -57,6 +57,7 @@ files =
   synapse/server_notices,
   synapse/spam_checker_api,
   synapse/state,
+  synapse/storage/databases/main/appservice.py,
   synapse/storage/databases/main/events.py,
   synapse/storage/databases/main/registration.py,
   synapse/storage/databases/main/stream.py,
@@ -82,6 +83,9 @@ ignore_missing_imports = True
 [mypy-zope]
 ignore_missing_imports = True
 
+[mypy-bcrypt]
+ignore_missing_imports = True
+
 [mypy-constantly]
 ignore_missing_imports = True
 
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 3ed29a2c16..9fc8444228 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -12,9 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
-from typing import Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 from prometheus_client import Counter
 
@@ -34,16 +33,20 @@ from synapse.metrics.background_process_metrics import (
     run_as_background_process,
     wrap_as_background_process,
 )
-from synapse.types import Collection, JsonDict, RoomStreamToken, UserID
+from synapse.storage.databases.main.directory import RoomAliasMapping
+from synapse.types import Collection, JsonDict, RoomAlias, RoomStreamToken, UserID
 from synapse.util.metrics import Measure
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 events_processed_counter = Counter("synapse_handlers_appservice_events_processed", "")
 
 
 class ApplicationServicesHandler:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.is_mine_id = hs.is_mine_id
         self.appservice_api = hs.get_application_service_api()
@@ -247,7 +250,9 @@ class ApplicationServicesHandler:
                         service, "presence", new_token
                     )
 
-    async def _handle_typing(self, service: ApplicationService, new_token: int):
+    async def _handle_typing(
+        self, service: ApplicationService, new_token: int
+    ) -> List[JsonDict]:
         typing_source = self.event_sources.sources["typing"]
         # Get the typing events from just before current
         typing, _ = await typing_source.get_new_events_as(
@@ -259,7 +264,7 @@ class ApplicationServicesHandler:
         )
         return typing
 
-    async def _handle_receipts(self, service: ApplicationService):
+    async def _handle_receipts(self, service: ApplicationService) -> List[JsonDict]:
         from_key = await self.store.get_type_stream_id_for_appservice(
             service, "read_receipt"
         )
@@ -271,7 +276,7 @@ class ApplicationServicesHandler:
 
     async def _handle_presence(
         self, service: ApplicationService, users: Collection[Union[str, UserID]]
-    ):
+    ) -> List[JsonDict]:
         events = []  # type: List[JsonDict]
         presence_source = self.event_sources.sources["presence"]
         from_key = await self.store.get_type_stream_id_for_appservice(
@@ -301,11 +306,11 @@ class ApplicationServicesHandler:
 
         return events
 
-    async def query_user_exists(self, user_id):
+    async def query_user_exists(self, user_id: str) -> bool:
         """Check if any application service knows this user_id exists.
 
         Args:
-            user_id(str): The user to query if they exist on any AS.
+            user_id: The user to query if they exist on any AS.
         Returns:
             True if this user exists on at least one application service.
         """
@@ -316,11 +321,13 @@ class ApplicationServicesHandler:
                 return True
         return False
 
-    async def query_room_alias_exists(self, room_alias):
+    async def query_room_alias_exists(
+        self, room_alias: RoomAlias
+    ) -> Optional[RoomAliasMapping]:
         """Check if an application service knows this room alias exists.
 
         Args:
-            room_alias(RoomAlias): The room alias to query.
+            room_alias: The room alias to query.
         Returns:
             namedtuple: with keys "room_id" and "servers" or None if no
             association can be found.
@@ -336,10 +343,13 @@ class ApplicationServicesHandler:
             )
             if is_known_alias:
                 # the alias exists now so don't query more ASes.
-                result = await self.store.get_association_from_room_alias(room_alias)
-                return result
+                return await self.store.get_association_from_room_alias(room_alias)
+
+        return None
 
-    async def query_3pe(self, kind, protocol, fields):
+    async def query_3pe(
+        self, kind: str, protocol: str, fields: Dict[bytes, List[bytes]]
+    ) -> List[JsonDict]:
         services = self._get_services_for_3pn(protocol)
 
         results = await make_deferred_yieldable(
@@ -361,7 +371,9 @@ class ApplicationServicesHandler:
 
         return ret
 
-    async def get_3pe_protocols(self, only_protocol=None):
+    async def get_3pe_protocols(
+        self, only_protocol: Optional[str] = None
+    ) -> Dict[str, JsonDict]:
         services = self.store.get_app_services()
         protocols = {}  # type: Dict[str, List[JsonDict]]
 
@@ -379,7 +391,7 @@ class ApplicationServicesHandler:
                 if info is not None:
                     protocols[p].append(info)
 
-        def _merge_instances(infos):
+        def _merge_instances(infos: List[JsonDict]) -> JsonDict:
             if not infos:
                 return {}
 
@@ -394,19 +406,17 @@ class ApplicationServicesHandler:
 
             return combined
 
-        for p in protocols.keys():
-            protocols[p] = _merge_instances(protocols[p])
+        return {p: _merge_instances(protocols[p]) for p in protocols.keys()}
 
-        return protocols
-
-    async def _get_services_for_event(self, event):
+    async def _get_services_for_event(
+        self, event: EventBase
+    ) -> List[ApplicationService]:
         """Retrieve a list of application services interested in this event.
 
         Args:
-            event(Event): The event to check. Can be None if alias_list is not.
+            event: The event to check. Can be None if alias_list is not.
         Returns:
-            list<ApplicationService>: A list of services interested in this
-            event based on the service regex.
+            A list of services interested in this event based on the service regex.
         """
         services = self.store.get_app_services()
 
@@ -420,17 +430,15 @@ class ApplicationServicesHandler:
 
         return interested_list
 
-    def _get_services_for_user(self, user_id):
+    def _get_services_for_user(self, user_id: str) -> List[ApplicationService]:
         services = self.store.get_app_services()
-        interested_list = [s for s in services if (s.is_interested_in_user(user_id))]
-        return interested_list
+        return [s for s in services if (s.is_interested_in_user(user_id))]
 
-    def _get_services_for_3pn(self, protocol):
+    def _get_services_for_3pn(self, protocol: str) -> List[ApplicationService]:
         services = self.store.get_app_services()
-        interested_list = [s for s in services if s.is_interested_in_protocol(protocol)]
-        return interested_list
+        return [s for s in services if s.is_interested_in_protocol(protocol)]
 
-    async def _is_unknown_user(self, user_id):
+    async def _is_unknown_user(self, user_id: str) -> bool:
         if not self.is_mine_id(user_id):
             # we don't know if they are unknown or not since it isn't one of our
             # users. We can't poke ASes.
@@ -445,9 +453,8 @@ class ApplicationServicesHandler:
         service_list = [s for s in services if s.sender == user_id]
         return len(service_list) == 0
 
-    async def _check_user_exists(self, user_id):
+    async def _check_user_exists(self, user_id: str) -> bool:
         unknown_user = await self._is_unknown_user(user_id)
         if unknown_user:
-            exists = await self.query_user_exists(user_id)
-            return exists
+            return await self.query_user_exists(user_id)
         return True
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index dd14ab69d7..276594f3d9 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -18,10 +18,20 @@ import logging
 import time
 import unicodedata
 import urllib.parse
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)
 
 import attr
-import bcrypt  # type: ignore[import]
+import bcrypt
 import pymacaroons
 
 from synapse.api.constants import LoginType
@@ -49,6 +59,9 @@ from synapse.util.threepids import canonicalise_email
 
 from ._base import BaseHandler
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -149,11 +162,7 @@ class SsoLoginExtraAttributes:
 class AuthHandler(BaseHandler):
     SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
 
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer):
-        """
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         self.checkers = {}  # type: Dict[str, UserInteractiveAuthChecker]
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 637a938bac..26eef6eb61 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -15,21 +15,31 @@
 # limitations under the License.
 import logging
 import re
-from typing import List
+from typing import TYPE_CHECKING, List, Optional, Pattern, Tuple
 
-from synapse.appservice import ApplicationService, AppServiceTransaction
+from synapse.appservice import (
+    ApplicationService,
+    ApplicationServiceState,
+    AppServiceTransaction,
+)
 from synapse.config.appservice import load_appservices
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.storage.types import Connection
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
-def _make_exclusive_regex(services_cache):
+def _make_exclusive_regex(
+    services_cache: List[ApplicationService],
+) -> Optional[Pattern]:
     # We precompile a regex constructed from all the regexes that the AS's
     # have registered for exclusive users.
     exclusive_user_regexes = [
@@ -39,17 +49,19 @@ def _make_exclusive_regex(services_cache):
     ]
     if exclusive_user_regexes:
         exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
-        exclusive_user_regex = re.compile(exclusive_user_regex)
+        exclusive_user_pattern = re.compile(
+            exclusive_user_regex
+        )  # type: Optional[Pattern]
     else:
         # We handle this case specially otherwise the constructed regex
         # will always match
-        exclusive_user_regex = None
+        exclusive_user_pattern = None
 
-    return exclusive_user_regex
+    return exclusive_user_pattern
 
 
 class ApplicationServiceWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
         self.services_cache = load_appservices(
             hs.hostname, hs.config.app_service_config_files
         )
@@ -60,7 +72,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
     def get_app_services(self):
         return self.services_cache
 
-    def get_if_app_services_interested_in_user(self, user_id):
+    def get_if_app_services_interested_in_user(self, user_id: str) -> bool:
         """Check if the user is one associated with an app service (exclusively)
         """
         if self.exclusive_user_regex:
@@ -68,7 +80,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
         else:
             return False
 
-    def get_app_service_by_user_id(self, user_id):
+    def get_app_service_by_user_id(self, user_id: str) -> Optional[ApplicationService]:
         """Retrieve an application service from their user ID.
 
         All application services have associated with them a particular user ID.
@@ -77,35 +89,35 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
         a user ID to an application service.
 
         Args:
-            user_id(str): The user ID to see if it is an application service.
+            user_id: The user ID to see if it is an application service.
         Returns:
-            synapse.appservice.ApplicationService or None.
+            The application service or None.
         """
         for service in self.services_cache:
             if service.sender == user_id:
                 return service
         return None
 
-    def get_app_service_by_token(self, token):
+    def get_app_service_by_token(self, token: str) -> Optional[ApplicationService]:
         """Get the application service with the given appservice token.
 
         Args:
-            token (str): The application service token.
+            token: The application service token.
         Returns:
-            synapse.appservice.ApplicationService or None.
+            The application service or None.
         """
         for service in self.services_cache:
             if service.token == token:
                 return service
         return None
 
-    def get_app_service_by_id(self, as_id):
+    def get_app_service_by_id(self, as_id: str) -> Optional[ApplicationService]:
         """Get the application service with the given appservice ID.
 
         Args:
-            as_id (str): The application service ID.
+            as_id: The application service ID.
         Returns:
-            synapse.appservice.ApplicationService or None.
+            The application service or None.
         """
         for service in self.services_cache:
             if service.id == as_id:
@@ -124,11 +136,13 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore):
 class ApplicationServiceTransactionWorkerStore(
     ApplicationServiceWorkerStore, EventsWorkerStore
 ):
-    async def get_appservices_by_state(self, state):
+    async def get_appservices_by_state(
+        self, state: ApplicationServiceState
+    ) -> List[ApplicationService]:
         """Get a list of application services based on their state.
 
         Args:
-            state(ApplicationServiceState): The state to filter on.
+            state: The state to filter on.
         Returns:
             A list of ApplicationServices, which may be empty.
         """
@@ -145,13 +159,15 @@ class ApplicationServiceTransactionWorkerStore(
                     services.append(service)
         return services
 
-    async def get_appservice_state(self, service):
+    async def get_appservice_state(
+        self, service: ApplicationService
+    ) -> Optional[ApplicationServiceState]:
         """Get the application service state.
 
         Args:
-            service(ApplicationService): The service whose state to set.
+            service: The service whose state to set.
         Returns:
-            An ApplicationServiceState.
+            An ApplicationServiceState or none.
         """
         result = await self.db_pool.simple_select_one(
             "application_services_state",
@@ -164,12 +180,14 @@ class ApplicationServiceTransactionWorkerStore(
             return result.get("state")
         return None
 
-    async def set_appservice_state(self, service, state) -> None:
+    async def set_appservice_state(
+        self, service: ApplicationService, state: ApplicationServiceState
+    ) -> None:
         """Set the application service state.
 
         Args:
-            service(ApplicationService): The service whose state to set.
-            state(ApplicationServiceState): The connectivity state to apply.
+            service: The service whose state to set.
+            state: The connectivity state to apply.
         """
         await self.db_pool.simple_upsert(
             "application_services_state", {"as_id": service.id}, {"state": state}
@@ -226,13 +244,14 @@ class ApplicationServiceTransactionWorkerStore(
             "create_appservice_txn", _create_appservice_txn
         )
 
-    async def complete_appservice_txn(self, txn_id, service) -> None:
+    async def complete_appservice_txn(
+        self, txn_id: int, service: ApplicationService
+    ) -> None:
         """Completes an application service transaction.
 
         Args:
-            txn_id(str): The transaction ID being completed.
-            service(ApplicationService): The application service which was sent
-            this transaction.
+            txn_id: The transaction ID being completed.
+            service: The application service which was sent this transaction.
         """
         txn_id = int(txn_id)
 
@@ -242,7 +261,7 @@ class ApplicationServiceTransactionWorkerStore(
             # has probably missed some events), so whine loudly but still continue,
             # since it shouldn't fail completion of the transaction.
             last_txn_id = self._get_last_txn(txn, service.id)
-            if (last_txn_id + 1) != txn_id:
+            if (txn_id + 1) != txn_id:
                 logger.error(
                     "appservice: Completing a transaction which has an ID > 1 from "
                     "the last ID sent to this AS. We've either dropped events or "
@@ -272,12 +291,13 @@ class ApplicationServiceTransactionWorkerStore(
             "complete_appservice_txn", _complete_appservice_txn
         )
 
-    async def get_oldest_unsent_txn(self, service):
-        """Get the oldest transaction which has not been sent for this
-        service.
+    async def get_oldest_unsent_txn(
+        self, service: ApplicationService
+    ) -> Optional[AppServiceTransaction]:
+        """Get the oldest transaction which has not been sent for this service.
 
         Args:
-            service(ApplicationService): The app service to get the oldest txn.
+            service: The app service to get the oldest txn.
         Returns:
             An AppServiceTransaction or None.
         """
@@ -313,7 +333,7 @@ class ApplicationServiceTransactionWorkerStore(
             service=service, id=entry["txn_id"], events=events, ephemeral=[]
         )
 
-    def _get_last_txn(self, txn, service_id):
+    def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
         txn.execute(
             "SELECT last_txn FROM application_services_state WHERE as_id=?",
             (service_id,),
@@ -324,7 +344,7 @@ class ApplicationServiceTransactionWorkerStore(
         else:
             return int(last_txn_id[0])  # select 'last_txn' col
 
-    async def set_appservice_last_pos(self, pos) -> None:
+    async def set_appservice_last_pos(self, pos: int) -> None:
         def set_appservice_last_pos_txn(txn):
             txn.execute(
                 "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
@@ -334,7 +354,9 @@ class ApplicationServiceTransactionWorkerStore(
             "set_appservice_last_pos", set_appservice_last_pos_txn
         )
 
-    async def get_new_events_for_appservice(self, current_id, limit):
+    async def get_new_events_for_appservice(
+        self, current_id: int, limit: int
+    ) -> Tuple[int, List[EventBase]]:
         """Get all new events for an appservice"""
 
         def get_new_events_for_appservice_txn(txn):
@@ -394,7 +416,7 @@ class ApplicationServiceTransactionWorkerStore(
         )
 
     async def set_type_stream_id_for_appservice(
-        self, service: ApplicationService, type: str, pos: int
+        self, service: ApplicationService, type: str, pos: Optional[int]
     ) -> None:
         if type not in ("read_receipt", "presence"):
             raise ValueError(