summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/room.py42
-rw-r--r--synapse/handlers/saml_handler.py28
2 files changed, 34 insertions, 36 deletions
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index da12df7f53..73f9eeb399 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -25,8 +25,6 @@ from collections import OrderedDict
 
 from six import iteritems, string_types
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset
 from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
@@ -103,8 +101,7 @@ class RoomCreationHandler(BaseHandler):
 
         self.third_party_event_rules = hs.get_third_party_event_rules()
 
-    @defer.inlineCallbacks
-    def upgrade_room(
+    async def upgrade_room(
         self, requester: Requester, old_room_id: str, new_version: RoomVersion
     ):
         """Replace a room with a new room with a different version
@@ -117,7 +114,7 @@ class RoomCreationHandler(BaseHandler):
         Returns:
             Deferred[unicode]: the new room id
         """
-        yield self.ratelimit(requester)
+        await self.ratelimit(requester)
 
         user_id = requester.user.to_string()
 
@@ -138,7 +135,7 @@ class RoomCreationHandler(BaseHandler):
         # If this user has sent multiple upgrade requests for the same room
         # and one of them is not complete yet, cache the response and
         # return it to all subsequent requests
-        ret = yield self._upgrade_response_cache.wrap(
+        ret = await self._upgrade_response_cache.wrap(
             (old_room_id, user_id),
             self._upgrade_room,
             requester,
@@ -856,8 +853,7 @@ class RoomCreationHandler(BaseHandler):
         for (etype, state_key), content in initial_state.items():
             await send(etype=etype, state_key=state_key, content=content)
 
-    @defer.inlineCallbacks
-    def _generate_room_id(
+    async def _generate_room_id(
         self, creator_id: str, is_public: str, room_version: RoomVersion,
     ):
         # autogen room IDs and try to create it. We may clash, so just
@@ -869,7 +865,7 @@ class RoomCreationHandler(BaseHandler):
                 gen_room_id = RoomID(random_string, self.hs.hostname).to_string()
                 if isinstance(gen_room_id, bytes):
                     gen_room_id = gen_room_id.decode("utf-8")
-                yield self.store.store_room(
+                await self.store.store_room(
                     room_id=gen_room_id,
                     room_creator_user_id=creator_id,
                     is_public=is_public,
@@ -888,8 +884,7 @@ class RoomContextHandler(object):
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
 
-    @defer.inlineCallbacks
-    def get_event_context(self, user, room_id, event_id, limit, event_filter):
+    async def get_event_context(self, user, room_id, event_id, limit, event_filter):
         """Retrieves events, pagination tokens and state around a given event
         in a room.
 
@@ -908,7 +903,7 @@ class RoomContextHandler(object):
         before_limit = math.floor(limit / 2.0)
         after_limit = limit - before_limit
 
-        users = yield self.store.get_users_in_room(room_id)
+        users = await self.store.get_users_in_room(room_id)
         is_peeking = user.to_string() not in users
 
         def filter_evts(events):
@@ -916,17 +911,17 @@ class RoomContextHandler(object):
                 self.storage, user.to_string(), events, is_peeking=is_peeking
             )
 
-        event = yield self.store.get_event(
+        event = await self.store.get_event(
             event_id, get_prev_content=True, allow_none=True
         )
         if not event:
             return None
 
-        filtered = yield (filter_evts([event]))
+        filtered = await filter_evts([event])
         if not filtered:
             raise AuthError(403, "You don't have permission to access that event.")
 
-        results = yield self.store.get_events_around(
+        results = await self.store.get_events_around(
             room_id, event_id, before_limit, after_limit, event_filter
         )
 
@@ -934,8 +929,8 @@ class RoomContextHandler(object):
             results["events_before"] = event_filter.filter(results["events_before"])
             results["events_after"] = event_filter.filter(results["events_after"])
 
-        results["events_before"] = yield filter_evts(results["events_before"])
-        results["events_after"] = yield filter_evts(results["events_after"])
+        results["events_before"] = await filter_evts(results["events_before"])
+        results["events_after"] = await filter_evts(results["events_after"])
         # filter_evts can return a pruned event in case the user is allowed to see that
         # there's something there but not see the content, so use the event that's in
         # `filtered` rather than the event we retrieved from the datastore.
@@ -962,7 +957,7 @@ class RoomContextHandler(object):
         # first? Shouldn't we be consistent with /sync?
         # https://github.com/matrix-org/matrix-doc/issues/687
 
-        state = yield self.state_store.get_state_for_events(
+        state = await self.state_store.get_state_for_events(
             [last_event_id], state_filter=state_filter
         )
 
@@ -970,7 +965,7 @@ class RoomContextHandler(object):
         if event_filter:
             state_events = event_filter.filter(state_events)
 
-        results["state"] = yield filter_evts(state_events)
+        results["state"] = await filter_evts(state_events)
 
         # We use a dummy token here as we only care about the room portion of
         # the token, which we replace.
@@ -989,13 +984,12 @@ class RoomEventSource(object):
     def __init__(self, hs):
         self.store = hs.get_datastore()
 
-    @defer.inlineCallbacks
-    def get_new_events(
+    async def get_new_events(
         self, user, from_key, limit, room_ids, is_guest, explicit_room_id=None
     ):
         # We just ignore the key for now.
 
-        to_key = yield self.get_current_key()
+        to_key = await self.get_current_key()
 
         from_token = RoomStreamToken.parse(from_key)
         if from_token.topological:
@@ -1008,11 +1002,11 @@ class RoomEventSource(object):
             # See https://github.com/matrix-org/matrix-doc/issues/1144
             raise NotImplementedError()
         else:
-            room_events = yield self.store.get_membership_changes_for_user(
+            room_events = await self.store.get_membership_changes_for_user(
                 user.to_string(), from_key, to_key
             )
 
-            room_to_events = yield self.store.get_room_events_stream_for_rooms(
+            room_to_events = await self.store.get_room_events_stream_for_rooms(
                 room_ids=room_ids,
                 from_key=from_key,
                 to_key=to_key,
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 96f2dd36ad..e7015c704f 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import logging
 import re
-from typing import Optional, Tuple
+from typing import Callable, Dict, Optional, Set, Tuple
 
 import attr
 import saml2
@@ -25,6 +25,7 @@ from synapse.api.errors import SynapseError
 from synapse.config import ConfigError
 from synapse.http.server import finish_request
 from synapse.http.servlet import parse_string
+from synapse.http.site import SynapseRequest
 from synapse.module_api import ModuleApi
 from synapse.module_api.errors import RedirectException
 from synapse.types import (
@@ -81,17 +82,19 @@ class SamlHandler:
 
         self._error_html_content = hs.config.saml2_error_html_content
 
-    def handle_redirect_request(self, client_redirect_url, ui_auth_session_id=None):
+    def handle_redirect_request(
+        self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
+    ) -> bytes:
         """Handle an incoming request to /login/sso/redirect
 
         Args:
-            client_redirect_url (bytes): the URL that we should redirect the
+            client_redirect_url: the URL that we should redirect the
                 client to when everything is done
-            ui_auth_session_id (Optional[str]): The session ID of the ongoing UI Auth (or
+            ui_auth_session_id: The session ID of the ongoing UI Auth (or
                 None if this is a login).
 
         Returns:
-            bytes: URL to redirect to
+            URL to redirect to
         """
         reqid, info = self._saml_client.prepare_for_authenticate(
             relay_state=client_redirect_url
@@ -109,15 +112,15 @@ class SamlHandler:
         # this shouldn't happen!
         raise Exception("prepare_for_authenticate didn't return a Location header")
 
-    async def handle_saml_response(self, request):
+    async def handle_saml_response(self, request: SynapseRequest) -> None:
         """Handle an incoming request to /_matrix/saml2/authn_response
 
         Args:
-            request (SynapseRequest): the incoming request from the browser. We'll
+            request: the incoming request from the browser. We'll
                 respond to it with a redirect.
 
         Returns:
-            Deferred[none]: Completes once we have handled the request.
+            Completes once we have handled the request.
         """
         resp_bytes = parse_string(request, "SAMLResponse", required=True)
         relay_state = parse_string(request, "RelayState", required=True)
@@ -310,6 +313,7 @@ DOT_REPLACE_PATTERN = re.compile(
 
 
 def dot_replace_for_mxid(username: str) -> str:
+    """Replace any characters which are not allowed in Matrix IDs with a dot."""
     username = username.lower()
     username = DOT_REPLACE_PATTERN.sub(".", username)
 
@@ -321,7 +325,7 @@ def dot_replace_for_mxid(username: str) -> str:
 MXID_MAPPER_MAP = {
     "hexencode": map_username_to_mxid_localpart,
     "dotreplace": dot_replace_for_mxid,
-}
+}  # type: Dict[str, Callable[[str], str]]
 
 
 @attr.s
@@ -349,7 +353,7 @@ class DefaultSamlMappingProvider(object):
 
     def get_remote_user_id(
         self, saml_response: saml2.response.AuthnResponse, client_redirect_url: str
-    ):
+    ) -> str:
         """Extracts the remote user id from the SAML response"""
         try:
             return saml_response.ava["uid"][0]
@@ -428,14 +432,14 @@ class DefaultSamlMappingProvider(object):
         return SamlConfig(mxid_source_attribute, mxid_mapper)
 
     @staticmethod
-    def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]:
+    def get_saml_attributes(config: SamlConfig) -> Tuple[Set[str], Set[str]]:
         """Returns the required attributes of a SAML
 
         Args:
             config: A SamlConfig object containing configuration params for this provider
 
         Returns:
-            tuple[set,set]: The first set equates to the saml auth response
+            The first set equates to the saml auth response
                 attributes that are required for the module to function, whereas the
                 second set consists of those attributes which can be used if
                 available, but are not necessary