summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2021-10-18 15:01:10 -0400
committerGitHub <noreply@github.com>2021-10-18 15:01:10 -0400
commit3ab55d43bd66b377c1ed94a40931eba98dd07b01 (patch)
tree44ebe940127f32bde6e497211f5baf7647442117 /synapse
parentCheck auth on received events' auth_events (#11001) (diff)
downloadsynapse-3ab55d43bd66b377c1ed94a40931eba98dd07b01.tar.xz
Add missing type hints to synapse.api. (#11109)
* Convert UserPresenceState to attrs.
* Remove args/kwargs from error classes and explicitly pass msg/errorcode.
Diffstat (limited to '')
-rw-r--r--synapse/api/auth.py14
-rw-r--r--synapse/api/errors.py69
-rw-r--r--synapse/api/filtering.py18
-rw-r--r--synapse/api/presence.py51
-rw-r--r--synapse/api/ratelimiting.py4
-rw-r--r--synapse/api/urls.py13
-rw-r--r--synapse/handlers/presence.py2
-rw-r--r--synapse/storage/databases/main/registration.py8
8 files changed, 80 insertions, 99 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index e6ca9232ee..44883c6663 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -245,7 +245,7 @@ class Auth:
 
     async def validate_appservice_can_control_user_id(
         self, app_service: ApplicationService, user_id: str
-    ):
+    ) -> None:
         """Validates that the app service is allowed to control
         the given user.
 
@@ -618,5 +618,13 @@ class Auth:
                 % (user_id, room_id),
             )
 
-    async def check_auth_blocking(self, *args, **kwargs) -> None:
-        await self._auth_blocking.check_auth_blocking(*args, **kwargs)
+    async def check_auth_blocking(
+        self,
+        user_id: Optional[str] = None,
+        threepid: Optional[dict] = None,
+        user_type: Optional[str] = None,
+        requester: Optional[Requester] = None,
+    ) -> None:
+        await self._auth_blocking.check_auth_blocking(
+            user_id=user_id, threepid=threepid, user_type=user_type, requester=requester
+        )
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 9480f448d7..685d1c25cf 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -18,7 +18,7 @@
 import logging
 import typing
 from http import HTTPStatus
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 from twisted.web import http
 
@@ -143,7 +143,7 @@ class SynapseError(CodeMessageException):
         super().__init__(code, msg)
         self.errcode = errcode
 
-    def error_dict(self):
+    def error_dict(self) -> "JsonDict":
         return cs_error(self.msg, self.errcode)
 
 
@@ -175,7 +175,7 @@ class ProxiedRequestError(SynapseError):
         else:
             self._additional_fields = dict(additional_fields)
 
-    def error_dict(self):
+    def error_dict(self) -> "JsonDict":
         return cs_error(self.msg, self.errcode, **self._additional_fields)
 
 
@@ -196,7 +196,7 @@ class ConsentNotGivenError(SynapseError):
         )
         self._consent_uri = consent_uri
 
-    def error_dict(self):
+    def error_dict(self) -> "JsonDict":
         return cs_error(self.msg, self.errcode, consent_uri=self._consent_uri)
 
 
@@ -262,14 +262,10 @@ class InteractiveAuthIncompleteError(Exception):
 class UnrecognizedRequestError(SynapseError):
     """An error indicating we don't understand the request you're trying to make"""
 
-    def __init__(self, *args, **kwargs):
-        if "errcode" not in kwargs:
-            kwargs["errcode"] = Codes.UNRECOGNIZED
-        if len(args) == 0:
-            message = "Unrecognized request"
-        else:
-            message = args[0]
-        super().__init__(400, message, **kwargs)
+    def __init__(
+        self, msg: str = "Unrecognized request", errcode: str = Codes.UNRECOGNIZED
+    ):
+        super().__init__(400, msg, errcode)
 
 
 class NotFoundError(SynapseError):
@@ -284,10 +280,8 @@ class AuthError(SynapseError):
     other poorly-defined times.
     """
 
-    def __init__(self, *args, **kwargs):
-        if "errcode" not in kwargs:
-            kwargs["errcode"] = Codes.FORBIDDEN
-        super().__init__(*args, **kwargs)
+    def __init__(self, code: int, msg: str, errcode: str = Codes.FORBIDDEN):
+        super().__init__(code, msg, errcode)
 
 
 class InvalidClientCredentialsError(SynapseError):
@@ -321,7 +315,7 @@ class InvalidClientTokenError(InvalidClientCredentialsError):
         super().__init__(msg=msg, errcode="M_UNKNOWN_TOKEN")
         self._soft_logout = soft_logout
 
-    def error_dict(self):
+    def error_dict(self) -> "JsonDict":
         d = super().error_dict()
         d["soft_logout"] = self._soft_logout
         return d
@@ -345,7 +339,7 @@ class ResourceLimitError(SynapseError):
         self.limit_type = limit_type
         super().__init__(code, msg, errcode=errcode)
 
-    def error_dict(self):
+    def error_dict(self) -> "JsonDict":
         return cs_error(
             self.msg,
             self.errcode,
@@ -357,32 +351,17 @@ class ResourceLimitError(SynapseError):
 class EventSizeError(SynapseError):
     """An error raised when an event is too big."""
 
-    def __init__(self, *args, **kwargs):
-        if "errcode" not in kwargs:
-            kwargs["errcode"] = Codes.TOO_LARGE
-        super().__init__(413, *args, **kwargs)
-
-
-class EventStreamError(SynapseError):
-    """An error raised when there a problem with the event stream."""
-
-    def __init__(self, *args, **kwargs):
-        if "errcode" not in kwargs:
-            kwargs["errcode"] = Codes.BAD_PAGINATION
-        super().__init__(*args, **kwargs)
+    def __init__(self, msg: str):
+        super().__init__(413, msg, Codes.TOO_LARGE)
 
 
 class LoginError(SynapseError):
     """An error raised when there was a problem logging in."""
 
-    pass
-
 
 class StoreError(SynapseError):
     """An error raised when there was a problem storing some data."""
 
-    pass
-
 
 class InvalidCaptchaError(SynapseError):
     def __init__(
@@ -395,7 +374,7 @@ class InvalidCaptchaError(SynapseError):
         super().__init__(code, msg, errcode)
         self.error_url = error_url
 
-    def error_dict(self):
+    def error_dict(self) -> "JsonDict":
         return cs_error(self.msg, self.errcode, error_url=self.error_url)
 
 
@@ -412,7 +391,7 @@ class LimitExceededError(SynapseError):
         super().__init__(code, msg, errcode)
         self.retry_after_ms = retry_after_ms
 
-    def error_dict(self):
+    def error_dict(self) -> "JsonDict":
         return cs_error(self.msg, self.errcode, retry_after_ms=self.retry_after_ms)
 
 
@@ -443,10 +422,8 @@ class UnsupportedRoomVersionError(SynapseError):
 class ThreepidValidationError(SynapseError):
     """An error raised when there was a problem authorising an event."""
 
-    def __init__(self, *args, **kwargs):
-        if "errcode" not in kwargs:
-            kwargs["errcode"] = Codes.FORBIDDEN
-        super().__init__(*args, **kwargs)
+    def __init__(self, msg: str, errcode: str = Codes.FORBIDDEN):
+        super().__init__(400, msg, errcode)
 
 
 class IncompatibleRoomVersionError(SynapseError):
@@ -466,7 +443,7 @@ class IncompatibleRoomVersionError(SynapseError):
 
         self._room_version = room_version
 
-    def error_dict(self):
+    def error_dict(self) -> "JsonDict":
         return cs_error(self.msg, self.errcode, room_version=self._room_version)
 
 
@@ -494,7 +471,7 @@ class RequestSendFailed(RuntimeError):
     errors (like programming errors).
     """
 
-    def __init__(self, inner_exception, can_retry):
+    def __init__(self, inner_exception: BaseException, can_retry: bool):
         super().__init__(
             "Failed to send request: %s: %s"
             % (type(inner_exception).__name__, inner_exception)
@@ -503,7 +480,7 @@ class RequestSendFailed(RuntimeError):
         self.can_retry = can_retry
 
 
-def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs):
+def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs: Any) -> "JsonDict":
     """Utility method for constructing an error response for client-server
     interactions.
 
@@ -551,7 +528,7 @@ class FederationError(RuntimeError):
         msg = "%s %s: %s" % (level, code, reason)
         super().__init__(msg)
 
-    def get_dict(self):
+    def get_dict(self) -> "JsonDict":
         return {
             "level": self.level,
             "code": self.code,
@@ -580,7 +557,7 @@ class HttpResponseException(CodeMessageException):
         super().__init__(code, msg)
         self.response = response
 
-    def to_synapse_error(self):
+    def to_synapse_error(self) -> SynapseError:
         """Make a SynapseError based on an HTTPResponseException
 
         This is useful when a proxied request has failed, and we need to
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 20e91a115d..bc550ae646 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -231,24 +231,24 @@ class FilterCollection:
     def include_redundant_members(self) -> bool:
         return self._room_state_filter.include_redundant_members()
 
-    def filter_presence(self, events):
+    def filter_presence(
+        self, events: Iterable[UserPresenceState]
+    ) -> List[UserPresenceState]:
         return self._presence_filter.filter(events)
 
-    def filter_account_data(self, events):
+    def filter_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]:
         return self._account_data.filter(events)
 
-    def filter_room_state(self, events):
+    def filter_room_state(self, events: Iterable[EventBase]) -> List[EventBase]:
         return self._room_state_filter.filter(self._room_filter.filter(events))
 
-    def filter_room_timeline(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
+    def filter_room_timeline(self, events: Iterable[EventBase]) -> List[EventBase]:
         return self._room_timeline_filter.filter(self._room_filter.filter(events))
 
-    def filter_room_ephemeral(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
+    def filter_room_ephemeral(self, events: Iterable[JsonDict]) -> List[JsonDict]:
         return self._room_ephemeral_filter.filter(self._room_filter.filter(events))
 
-    def filter_room_account_data(
-        self, events: Iterable[FilterEvent]
-    ) -> List[FilterEvent]:
+    def filter_room_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]:
         return self._room_account_data.filter(self._room_filter.filter(events))
 
     def blocks_all_presence(self) -> bool:
@@ -309,7 +309,7 @@ class Filter:
         # except for presence which actually gets passed around as its own
         # namedtuple type.
         if isinstance(event, UserPresenceState):
-            sender = event.user_id
+            sender: Optional[str] = event.user_id
             room_id = None
             ev_type = "m.presence"
             contains_url = False
diff --git a/synapse/api/presence.py b/synapse/api/presence.py
index a3bf0348d1..b80aa83cb3 100644
--- a/synapse/api/presence.py
+++ b/synapse/api/presence.py
@@ -12,49 +12,48 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from collections import namedtuple
+from typing import Any, Optional
+
+import attr
 
 from synapse.api.constants import PresenceState
+from synapse.types import JsonDict
 
 
-class UserPresenceState(
-    namedtuple(
-        "UserPresenceState",
-        (
-            "user_id",
-            "state",
-            "last_active_ts",
-            "last_federation_update_ts",
-            "last_user_sync_ts",
-            "status_msg",
-            "currently_active",
-        ),
-    )
-):
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class UserPresenceState:
     """Represents the current presence state of the user.
 
-    user_id (str)
-    last_active (int): Time in msec that the user last interacted with server.
-    last_federation_update (int): Time in msec since either a) we sent a presence
+    user_id
+    last_active: Time in msec that the user last interacted with server.
+    last_federation_update: Time in msec since either a) we sent a presence
         update to other servers or b) we received a presence update, depending
         on if is a local user or not.
-    last_user_sync (int): Time in msec that the user last *completed* a sync
+    last_user_sync: Time in msec that the user last *completed* a sync
         (or event stream).
-    status_msg (str): User set status message.
+    status_msg: User set status message.
     """
 
-    def as_dict(self):
-        return dict(self._asdict())
+    user_id: str
+    state: str
+    last_active_ts: int
+    last_federation_update_ts: int
+    last_user_sync_ts: int
+    status_msg: Optional[str]
+    currently_active: bool
+
+    def as_dict(self) -> JsonDict:
+        return attr.asdict(self)
 
     @staticmethod
-    def from_dict(d):
+    def from_dict(d: JsonDict) -> "UserPresenceState":
         return UserPresenceState(**d)
 
-    def copy_and_replace(self, **kwargs):
-        return self._replace(**kwargs)
+    def copy_and_replace(self, **kwargs: Any) -> "UserPresenceState":
+        return attr.evolve(self, **kwargs)
 
     @classmethod
-    def default(cls, user_id):
+    def default(cls, user_id: str) -> "UserPresenceState":
         """Returns a default presence state."""
         return cls(
             user_id=user_id,
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index e8964097d3..849c18ceda 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -161,7 +161,7 @@ class Ratelimiter:
 
         return allowed, time_allowed
 
-    def _prune_message_counts(self, time_now_s: float):
+    def _prune_message_counts(self, time_now_s: float) -> None:
         """Remove message count entries that have not exceeded their defined
         rate_hz limit
 
@@ -190,7 +190,7 @@ class Ratelimiter:
         update: bool = True,
         n_actions: int = 1,
         _time_now_s: Optional[float] = None,
-    ):
+    ) -> None:
         """Checks if an action can be performed. If not, raises a LimitExceededError
 
         Checks if the user has ratelimiting disabled in the database by looking
diff --git a/synapse/api/urls.py b/synapse/api/urls.py
index 032c69b210..6e84b1524f 100644
--- a/synapse/api/urls.py
+++ b/synapse/api/urls.py
@@ -19,6 +19,7 @@ from hashlib import sha256
 from urllib.parse import urlencode
 
 from synapse.config import ConfigError
+from synapse.config.homeserver import HomeServerConfig
 
 SYNAPSE_CLIENT_API_PREFIX = "/_synapse/client"
 CLIENT_API_PREFIX = "/_matrix/client"
@@ -34,11 +35,7 @@ LEGACY_MEDIA_PREFIX = "/_matrix/media/v1"
 
 
 class ConsentURIBuilder:
-    def __init__(self, hs_config):
-        """
-        Args:
-            hs_config (synapse.config.homeserver.HomeServerConfig):
-        """
+    def __init__(self, hs_config: HomeServerConfig):
         if hs_config.key.form_secret is None:
             raise ConfigError("form_secret not set in config")
         if hs_config.server.public_baseurl is None:
@@ -47,15 +44,15 @@ class ConsentURIBuilder:
         self._hmac_secret = hs_config.key.form_secret.encode("utf-8")
         self._public_baseurl = hs_config.server.public_baseurl
 
-    def build_user_consent_uri(self, user_id):
+    def build_user_consent_uri(self, user_id: str) -> str:
         """Build a URI which we can give to the user to do their privacy
         policy consent
 
         Args:
-            user_id (str): mxid or username of user
+            user_id: mxid or username of user
 
         Returns
-            (str) the URI where the user can do consent
+            The URI where the user can do consent
         """
         mac = hmac.new(
             key=self._hmac_secret, msg=user_id.encode("ascii"), digestmod=sha256
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 404afb9402..b5968e047b 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -1489,7 +1489,7 @@ def format_user_presence_state(
     The "user_id" is optional so that this function can be used to format presence
     updates for client /sync responses and for federation /send requests.
     """
-    content = {"presence": state.state}
+    content: JsonDict = {"presence": state.state}
     if include_user_id:
         content["user_id"] = state.user_id
     if state.last_active_ts:
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 181841ee06..0ab56d8a07 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -2237,7 +2237,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
                     # accident.
                     row = {"client_secret": None, "validated_at": None}
                 else:
-                    raise ThreepidValidationError(400, "Unknown session_id")
+                    raise ThreepidValidationError("Unknown session_id")
 
             retrieved_client_secret = row["client_secret"]
             validated_at = row["validated_at"]
@@ -2252,14 +2252,14 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
 
             if not row:
                 raise ThreepidValidationError(
-                    400, "Validation token not found or has expired"
+                    "Validation token not found or has expired"
                 )
             expires = row["expires"]
             next_link = row["next_link"]
 
             if retrieved_client_secret != client_secret:
                 raise ThreepidValidationError(
-                    400, "This client_secret does not match the provided session_id"
+                    "This client_secret does not match the provided session_id"
                 )
 
             # If the session is already validated, no need to revalidate
@@ -2268,7 +2268,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
 
             if expires <= current_ts:
                 raise ThreepidValidationError(
-                    400, "This token has expired. Please request a new one"
+                    "This token has expired. Please request a new one"
                 )
 
             # Looks good. Validate the session