diff options
Diffstat (limited to 'synapse/handlers')
-rw-r--r-- | synapse/handlers/appservice.py | 75 | ||||
-rw-r--r-- | synapse/handlers/auth.py | 23 |
2 files changed, 57 insertions, 41 deletions
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 3ed29a2c16..9fc8444228 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -12,9 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging -from typing import Dict, List, Optional, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Union from prometheus_client import Counter @@ -34,16 +33,20 @@ from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, ) -from synapse.types import Collection, JsonDict, RoomStreamToken, UserID +from synapse.storage.databases.main.directory import RoomAliasMapping +from synapse.types import Collection, JsonDict, RoomAlias, RoomStreamToken, UserID from synapse.util.metrics import Measure +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) events_processed_counter = Counter("synapse_handlers_appservice_events_processed", "") class ApplicationServicesHandler: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.is_mine_id = hs.is_mine_id self.appservice_api = hs.get_application_service_api() @@ -247,7 +250,9 @@ class ApplicationServicesHandler: service, "presence", new_token ) - async def _handle_typing(self, service: ApplicationService, new_token: int): + async def _handle_typing( + self, service: ApplicationService, new_token: int + ) -> List[JsonDict]: typing_source = self.event_sources.sources["typing"] # Get the typing events from just before current typing, _ = await typing_source.get_new_events_as( @@ -259,7 +264,7 @@ class ApplicationServicesHandler: ) return typing - async def _handle_receipts(self, service: ApplicationService): + async def _handle_receipts(self, service: ApplicationService) -> List[JsonDict]: from_key = await self.store.get_type_stream_id_for_appservice( service, "read_receipt" ) @@ -271,7 +276,7 @@ class ApplicationServicesHandler: async def _handle_presence( self, service: ApplicationService, users: Collection[Union[str, UserID]] - ): + ) -> List[JsonDict]: events = [] # type: List[JsonDict] presence_source = self.event_sources.sources["presence"] from_key = await self.store.get_type_stream_id_for_appservice( @@ -301,11 +306,11 @@ class ApplicationServicesHandler: return events - async def query_user_exists(self, user_id): + async def query_user_exists(self, user_id: str) -> bool: """Check if any application service knows this user_id exists. Args: - user_id(str): The user to query if they exist on any AS. + user_id: The user to query if they exist on any AS. Returns: True if this user exists on at least one application service. """ @@ -316,11 +321,13 @@ class ApplicationServicesHandler: return True return False - async def query_room_alias_exists(self, room_alias): + async def query_room_alias_exists( + self, room_alias: RoomAlias + ) -> Optional[RoomAliasMapping]: """Check if an application service knows this room alias exists. Args: - room_alias(RoomAlias): The room alias to query. + room_alias: The room alias to query. Returns: namedtuple: with keys "room_id" and "servers" or None if no association can be found. @@ -336,10 +343,13 @@ class ApplicationServicesHandler: ) if is_known_alias: # the alias exists now so don't query more ASes. - result = await self.store.get_association_from_room_alias(room_alias) - return result + return await self.store.get_association_from_room_alias(room_alias) + + return None - async def query_3pe(self, kind, protocol, fields): + async def query_3pe( + self, kind: str, protocol: str, fields: Dict[bytes, List[bytes]] + ) -> List[JsonDict]: services = self._get_services_for_3pn(protocol) results = await make_deferred_yieldable( @@ -361,7 +371,9 @@ class ApplicationServicesHandler: return ret - async def get_3pe_protocols(self, only_protocol=None): + async def get_3pe_protocols( + self, only_protocol: Optional[str] = None + ) -> Dict[str, JsonDict]: services = self.store.get_app_services() protocols = {} # type: Dict[str, List[JsonDict]] @@ -379,7 +391,7 @@ class ApplicationServicesHandler: if info is not None: protocols[p].append(info) - def _merge_instances(infos): + def _merge_instances(infos: List[JsonDict]) -> JsonDict: if not infos: return {} @@ -394,19 +406,17 @@ class ApplicationServicesHandler: return combined - for p in protocols.keys(): - protocols[p] = _merge_instances(protocols[p]) + return {p: _merge_instances(protocols[p]) for p in protocols.keys()} - return protocols - - async def _get_services_for_event(self, event): + async def _get_services_for_event( + self, event: EventBase + ) -> List[ApplicationService]: """Retrieve a list of application services interested in this event. Args: - event(Event): The event to check. Can be None if alias_list is not. + event: The event to check. Can be None if alias_list is not. Returns: - list<ApplicationService>: A list of services interested in this - event based on the service regex. + A list of services interested in this event based on the service regex. """ services = self.store.get_app_services() @@ -420,17 +430,15 @@ class ApplicationServicesHandler: return interested_list - def _get_services_for_user(self, user_id): + def _get_services_for_user(self, user_id: str) -> List[ApplicationService]: services = self.store.get_app_services() - interested_list = [s for s in services if (s.is_interested_in_user(user_id))] - return interested_list + return [s for s in services if (s.is_interested_in_user(user_id))] - def _get_services_for_3pn(self, protocol): + def _get_services_for_3pn(self, protocol: str) -> List[ApplicationService]: services = self.store.get_app_services() - interested_list = [s for s in services if s.is_interested_in_protocol(protocol)] - return interested_list + return [s for s in services if s.is_interested_in_protocol(protocol)] - async def _is_unknown_user(self, user_id): + async def _is_unknown_user(self, user_id: str) -> bool: if not self.is_mine_id(user_id): # we don't know if they are unknown or not since it isn't one of our # users. We can't poke ASes. @@ -445,9 +453,8 @@ class ApplicationServicesHandler: service_list = [s for s in services if s.sender == user_id] return len(service_list) == 0 - async def _check_user_exists(self, user_id): + async def _check_user_exists(self, user_id: str) -> bool: unknown_user = await self._is_unknown_user(user_id) if unknown_user: - exists = await self.query_user_exists(user_id) - return exists + return await self.query_user_exists(user_id) return True diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index dd14ab69d7..276594f3d9 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -18,10 +18,20 @@ import logging import time import unicodedata import urllib.parse -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Tuple, + Union, +) import attr -import bcrypt # type: ignore[import] +import bcrypt import pymacaroons from synapse.api.constants import LoginType @@ -49,6 +59,9 @@ from synapse.util.threepids import canonicalise_email from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) @@ -149,11 +162,7 @@ class SsoLoginExtraAttributes: class AuthHandler(BaseHandler): SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): - """ + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker] |