diff options
author | Amber Brown <hawkowl@atleastfornow.net> | 2019-06-20 19:32:02 +1000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-06-20 19:32:02 +1000 |
commit | 32e7c9e7f20b57dd081023ac42d6931a8da9b3a3 (patch) | |
tree | 139ef30c957535699d1ae0474e8b5ba2517196b2 /synapse/rest/client | |
parent | Merge pull request #5042 from matrix-org/erikj/fix_get_missing_events_error (diff) | |
download | synapse-32e7c9e7f20b57dd081023ac42d6931a8da9b3a3.tar.xz |
Run Black. (#5482)
Diffstat (limited to 'synapse/rest/client')
36 files changed, 688 insertions, 854 deletions
diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py index 48c17f1b6d..36404b797d 100644 --- a/synapse/rest/client/transactions.py +++ b/synapse/rest/client/transactions.py @@ -26,7 +26,6 @@ CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins class HttpTransactionCache(object): - def __init__(self, hs): self.hs = hs self.auth = self.hs.get_auth() @@ -53,7 +52,7 @@ class HttpTransactionCache(object): str: A transaction key """ token = self.auth.get_access_token_from_request(request) - return request.path.decode('utf8') + "/" + token + return request.path.decode("utf8") + "/" + token def fetch_or_execute_request(self, request, fn, *args, **kwargs): """A helper function for fetch_or_execute which extracts diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 0035182bb9..dd0d38ea5c 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -56,8 +56,9 @@ class ClientDirectoryServer(RestServlet): content = parse_json_object_from_request(request) if "room_id" not in content: - raise SynapseError(400, 'Missing params: ["room_id"]', - errcode=Codes.BAD_JSON) + raise SynapseError( + 400, 'Missing params: ["room_id"]', errcode=Codes.BAD_JSON + ) logger.debug("Got content: %s", content) logger.debug("Got room name: %s", room_alias.to_string()) @@ -89,13 +90,11 @@ class ClientDirectoryServer(RestServlet): try: service = yield self.auth.get_appservice_by_req(request) room_alias = RoomAlias.from_string(room_alias) - yield dir_handler.delete_appservice_association( - service, room_alias - ) + yield dir_handler.delete_appservice_association(service, room_alias) logger.info( "Application service at %s deleted alias %s", service.url, - room_alias.to_string() + room_alias.to_string(), ) defer.returnValue((200, {})) except AuthError: @@ -107,14 +106,10 @@ class ClientDirectoryServer(RestServlet): room_alias = RoomAlias.from_string(room_alias) - yield dir_handler.delete_association( - requester, room_alias - ) + yield dir_handler.delete_association(requester, room_alias) logger.info( - "User %s deleted alias %s", - user.to_string(), - room_alias.to_string() + "User %s deleted alias %s", user.to_string(), room_alias.to_string() ) defer.returnValue((200, {})) @@ -135,9 +130,9 @@ class ClientDirectoryListServer(RestServlet): if room is None: raise NotFoundError("Unknown room") - defer.returnValue((200, { - "visibility": "public" if room["is_public"] else "private" - })) + defer.returnValue( + (200, {"visibility": "public" if room["is_public"] else "private"}) + ) @defer.inlineCallbacks def on_PUT(self, request, room_id): @@ -147,7 +142,7 @@ class ClientDirectoryListServer(RestServlet): visibility = content.get("visibility", "public") yield self.handlers.directory_handler.edit_published_room_list( - requester, room_id, visibility, + requester, room_id, visibility ) defer.returnValue((200, {})) @@ -157,7 +152,7 @@ class ClientDirectoryListServer(RestServlet): requester = yield self.auth.get_user_by_req(request) yield self.handlers.directory_handler.edit_published_room_list( - requester, room_id, "private", + requester, room_id, "private" ) defer.returnValue((200, {})) @@ -191,7 +186,7 @@ class ClientAppserviceDirectoryListServer(RestServlet): ) yield self.handlers.directory_handler.edit_published_appservice_room_list( - requester.app_service.id, network_id, room_id, visibility, + requester.app_service.id, network_id, room_id, visibility ) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 84ca36270b..d6de2b7360 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -38,17 +38,14 @@ class EventStreamRestServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request): - requester = yield self.auth.get_user_by_req( - request, - allow_guest=True, - ) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) is_guest = requester.is_guest room_id = None if is_guest: if b"room_id" not in request.args: raise SynapseError(400, "Guest users must specify room_id param") if b"room_id" in request.args: - room_id = request.args[b"room_id"][0].decode('ascii') + room_id = request.args[b"room_id"][0].decode("ascii") pagin_config = PaginationConfig.from_request(request) timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 3b60728628..4efb679a04 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -44,10 +44,7 @@ def login_submission_legacy_convert(submission): to a typed object. """ if "user" in submission: - submission["identifier"] = { - "type": "m.id.user", - "user": submission["user"], - } + submission["identifier"] = {"type": "m.id.user", "user": submission["user"]} del submission["user"] if "medium" in submission and "address" in submission: @@ -73,11 +70,7 @@ def login_id_thirdparty_from_phone(identifier): msisdn = phone_number_to_msisdn(identifier["country"], identifier["number"]) - return { - "type": "m.id.thirdparty", - "medium": "msisdn", - "address": msisdn, - } + return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn} class LoginRestServlet(RestServlet): @@ -120,9 +113,9 @@ class LoginRestServlet(RestServlet): # login flow types returned. flows.append({"type": LoginRestServlet.TOKEN_TYPE}) - flows.extend(( - {"type": t} for t in self.auth_handler.get_supported_login_types() - )) + flows.extend( + ({"type": t} for t in self.auth_handler.get_supported_login_types()) + ) return (200, {"flows": flows}) @@ -132,7 +125,8 @@ class LoginRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request): self._address_ratelimiter.ratelimit( - request.getClientIP(), time_now_s=self.hs.clock.time(), + request.getClientIP(), + time_now_s=self.hs.clock.time(), rate_hz=self.hs.config.rc_login_address.per_second, burst_count=self.hs.config.rc_login_address.burst_count, update=True, @@ -140,8 +134,9 @@ class LoginRestServlet(RestServlet): login_submission = parse_json_object_from_request(request) try: - if self.jwt_enabled and (login_submission["type"] == - LoginRestServlet.JWT_TYPE): + if self.jwt_enabled and ( + login_submission["type"] == LoginRestServlet.JWT_TYPE + ): result = yield self.do_jwt_login(login_submission) elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: result = yield self.do_token_login(login_submission) @@ -170,10 +165,10 @@ class LoginRestServlet(RestServlet): # field) logger.info( "Got login request with identifier: %r, medium: %r, address: %r, user: %r", - login_submission.get('identifier'), - login_submission.get('medium'), - login_submission.get('address'), - login_submission.get('user'), + login_submission.get("identifier"), + login_submission.get("medium"), + login_submission.get("address"), + login_submission.get("user"), ) login_submission_legacy_convert(login_submission) @@ -190,13 +185,13 @@ class LoginRestServlet(RestServlet): # convert threepid identifiers to user IDs if identifier["type"] == "m.id.thirdparty": - address = identifier.get('address') - medium = identifier.get('medium') + address = identifier.get("address") + medium = identifier.get("medium") if medium is None or address is None: raise SynapseError(400, "Invalid thirdparty identifier") - if medium == 'email': + if medium == "email": # For emails, transform the address to lowercase. # We store all email addreses as lowercase in the DB. # (See add_threepid in synapse/handlers/auth.py) @@ -205,34 +200,28 @@ class LoginRestServlet(RestServlet): # Check for login providers that support 3pid login types canonical_user_id, callback_3pid = ( yield self.auth_handler.check_password_provider_3pid( - medium, - address, - login_submission["password"], + medium, address, login_submission["password"] ) ) if canonical_user_id: # Authentication through password provider and 3pid succeeded result = yield self._register_device_with_callback( - canonical_user_id, login_submission, callback_3pid, + canonical_user_id, login_submission, callback_3pid ) defer.returnValue(result) # No password providers were able to handle this 3pid # Check local store user_id = yield self.hs.get_datastore().get_user_id_by_threepid( - medium, address, + medium, address ) if not user_id: logger.warn( - "unknown 3pid identifier medium %s, address %r", - medium, address, + "unknown 3pid identifier medium %s, address %r", medium, address ) raise LoginError(403, "", errcode=Codes.FORBIDDEN) - identifier = { - "type": "m.id.user", - "user": user_id, - } + identifier = {"type": "m.id.user", "user": user_id} # by this point, the identifier should be an m.id.user: if it's anything # else, we haven't understood it. @@ -242,22 +231,16 @@ class LoginRestServlet(RestServlet): raise SynapseError(400, "User identifier is missing 'user' key") canonical_user_id, callback = yield self.auth_handler.validate_login( - identifier["user"], - login_submission, + identifier["user"], login_submission ) result = yield self._register_device_with_callback( - canonical_user_id, login_submission, callback, + canonical_user_id, login_submission, callback ) defer.returnValue(result) @defer.inlineCallbacks - def _register_device_with_callback( - self, - user_id, - login_submission, - callback=None, - ): + def _register_device_with_callback(self, user_id, login_submission, callback=None): """ Registers a device with a given user_id. Optionally run a callback function after registration has completed. @@ -273,7 +256,7 @@ class LoginRestServlet(RestServlet): device_id = login_submission.get("device_id") initial_display_name = login_submission.get("initial_device_display_name") device_id, access_token = yield self.registration_handler.register_device( - user_id, device_id, initial_display_name, + user_id, device_id, initial_display_name ) result = { @@ -290,7 +273,7 @@ class LoginRestServlet(RestServlet): @defer.inlineCallbacks def do_token_login(self, login_submission): - token = login_submission['token'] + token = login_submission["token"] auth_handler = self.auth_handler user_id = ( yield auth_handler.validate_short_term_login_token_and_get_user_id(token) @@ -299,7 +282,7 @@ class LoginRestServlet(RestServlet): device_id = login_submission.get("device_id") initial_display_name = login_submission.get("initial_device_display_name") device_id, access_token = yield self.registration_handler.register_device( - user_id, device_id, initial_display_name, + user_id, device_id, initial_display_name ) result = { @@ -316,15 +299,16 @@ class LoginRestServlet(RestServlet): token = login_submission.get("token", None) if token is None: raise LoginError( - 401, "Token field for JWT is missing", - errcode=Codes.UNAUTHORIZED + 401, "Token field for JWT is missing", errcode=Codes.UNAUTHORIZED ) import jwt from jwt.exceptions import InvalidTokenError try: - payload = jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm]) + payload = jwt.decode( + token, self.jwt_secret, algorithms=[self.jwt_algorithm] + ) except jwt.ExpiredSignatureError: raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED) except InvalidTokenError: @@ -342,7 +326,7 @@ class LoginRestServlet(RestServlet): device_id = login_submission.get("device_id") initial_display_name = login_submission.get("initial_device_display_name") device_id, access_token = yield self.registration_handler.register_device( - registered_user_id, device_id, initial_display_name, + registered_user_id, device_id, initial_display_name ) result = { @@ -358,7 +342,7 @@ class LoginRestServlet(RestServlet): device_id = login_submission.get("device_id") initial_display_name = login_submission.get("initial_device_display_name") device_id, access_token = yield self.registration_handler.register_device( - registered_user_id, device_id, initial_display_name, + registered_user_id, device_id, initial_display_name ) result = { @@ -375,21 +359,20 @@ class CasRedirectServlet(RestServlet): def __init__(self, hs): super(CasRedirectServlet, self).__init__() - self.cas_server_url = hs.config.cas_server_url.encode('ascii') - self.cas_service_url = hs.config.cas_service_url.encode('ascii') + self.cas_server_url = hs.config.cas_server_url.encode("ascii") + self.cas_service_url = hs.config.cas_service_url.encode("ascii") def on_GET(self, request): args = request.args if b"redirectUrl" not in args: return (400, "Redirect URL not specified for CAS auth") - client_redirect_url_param = urllib.parse.urlencode({ - b"redirectUrl": args[b"redirectUrl"][0] - }).encode('ascii') - hs_redirect_url = (self.cas_service_url + - b"/_matrix/client/r0/login/cas/ticket") - service_param = urllib.parse.urlencode({ - b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param) - }).encode('ascii') + client_redirect_url_param = urllib.parse.urlencode( + {b"redirectUrl": args[b"redirectUrl"][0]} + ).encode("ascii") + hs_redirect_url = self.cas_service_url + b"/_matrix/client/r0/login/cas/ticket" + service_param = urllib.parse.urlencode( + {b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)} + ).encode("ascii") request.redirect(b"%s/login?%s" % (self.cas_server_url, service_param)) finish_request(request) @@ -411,7 +394,7 @@ class CasTicketServlet(RestServlet): uri = self.cas_server_url + "/proxyValidate" args = { "ticket": parse_string(request, "ticket", required=True), - "service": self.cas_service_url + "service": self.cas_service_url, } try: body = yield self._http_client.get_raw(uri, args) @@ -438,7 +421,7 @@ class CasTicketServlet(RestServlet): raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED) return self._sso_auth_handler.on_successful_auth( - user, request, client_redirect_url, + user, request, client_redirect_url ) def parse_cas_response(self, cas_response_body): @@ -448,7 +431,7 @@ class CasTicketServlet(RestServlet): root = ET.fromstring(cas_response_body) if not root.tag.endswith("serviceResponse"): raise Exception("root of CAS response is not serviceResponse") - success = (root[0].tag.endswith("authenticationSuccess")) + success = root[0].tag.endswith("authenticationSuccess") for child in root[0]: if child.tag.endswith("user"): user = child.text @@ -466,11 +449,11 @@ class CasTicketServlet(RestServlet): raise Exception("CAS response does not contain user") except Exception: logger.error("Error parsing CAS response", exc_info=1) - raise LoginError(401, "Invalid CAS response", - errcode=Codes.UNAUTHORIZED) + raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) if not success: - raise LoginError(401, "Unsuccessful CAS response", - errcode=Codes.UNAUTHORIZED) + raise LoginError( + 401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED + ) return user, attributes @@ -482,6 +465,7 @@ class SSOAuthHandler(object): Args: hs (synapse.server.HomeServer) """ + def __init__(self, hs): self._hostname = hs.hostname self._auth_handler = hs.get_auth_handler() @@ -490,8 +474,7 @@ class SSOAuthHandler(object): @defer.inlineCallbacks def on_successful_auth( - self, username, request, client_redirect_url, - user_display_name=None, + self, username, request, client_redirect_url, user_display_name=None ): """Called once the user has successfully authenticated with the SSO. diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py index b8064f261e..cd711be519 100644 --- a/synapse/rest/client/v1/logout.py +++ b/synapse/rest/client/v1/logout.py @@ -46,7 +46,8 @@ class LogoutRestServlet(RestServlet): yield self._auth_handler.delete_access_token(access_token) else: yield self._device_handler.delete_device( - requester.user.to_string(), requester.device_id) + requester.user.to_string(), requester.device_id + ) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index e263da3cb7..3e87f0fdb3 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -47,7 +47,7 @@ class PresenceStatusRestServlet(RestServlet): if requester.user != user: allowed = yield self.presence_handler.is_visible( - observed_user=user, observer_user=requester.user, + observed_user=user, observer_user=requester.user ) if not allowed: diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index e15d9d82a6..4d8ab1f47e 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -63,8 +63,7 @@ class ProfileDisplaynameRestServlet(RestServlet): except Exception: defer.returnValue((400, "Unable to parse name")) - yield self.profile_handler.set_displayname( - user, requester, new_name, is_admin) + yield self.profile_handler.set_displayname(user, requester, new_name, is_admin) defer.returnValue((200, {})) @@ -113,8 +112,7 @@ class ProfileAvatarURLRestServlet(RestServlet): except Exception: defer.returnValue((400, "Unable to parse name")) - yield self.profile_handler.set_avatar_url( - user, requester, new_name, is_admin) + yield self.profile_handler.set_avatar_url(user, requester, new_name, is_admin) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 3d6326fe2f..e635efb420 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -21,7 +21,11 @@ from synapse.api.errors import ( SynapseError, UnrecognizedRequestError, ) -from synapse.http.servlet import RestServlet, parse_json_value_from_request, parse_string +from synapse.http.servlet import ( + RestServlet, + parse_json_value_from_request, + parse_string, +) from synapse.push.baserules import BASE_RULE_IDS from synapse.push.clientformat import format_push_rules_for_user from synapse.push.rulekinds import PRIORITY_CLASS_MAP @@ -32,7 +36,8 @@ from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundExc class PushRuleRestServlet(RestServlet): PATTERNS = client_patterns("/(?P<path>pushrules/.*)$", v1=True) SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = ( - "Unrecognised request: You probably wanted a trailing slash") + "Unrecognised request: You probably wanted a trailing slash" + ) def __init__(self, hs): super(PushRuleRestServlet, self).__init__() @@ -54,27 +59,25 @@ class PushRuleRestServlet(RestServlet): requester = yield self.auth.get_user_by_req(request) - if '/' in spec['rule_id'] or '\\' in spec['rule_id']: + if "/" in spec["rule_id"] or "\\" in spec["rule_id"]: raise SynapseError(400, "rule_id may not contain slashes") content = parse_json_value_from_request(request) user_id = requester.user.to_string() - if 'attr' in spec: + if "attr" in spec: yield self.set_rule_attr(user_id, spec, content) self.notify_user(user_id) defer.returnValue((200, {})) - if spec['rule_id'].startswith('.'): + if spec["rule_id"].startswith("."): # Rule ids starting with '.' are reserved for server default rules. raise SynapseError(400, "cannot add new rule_ids that start with '.'") try: (conditions, actions) = _rule_tuple_from_request_object( - spec['template'], - spec['rule_id'], - content, + spec["template"], spec["rule_id"], content ) except InvalidRuleException as e: raise SynapseError(400, str(e)) @@ -95,7 +98,7 @@ class PushRuleRestServlet(RestServlet): conditions=conditions, actions=actions, before=before, - after=after + after=after, ) self.notify_user(user_id) except InconsistentRuleException as e: @@ -118,9 +121,7 @@ class PushRuleRestServlet(RestServlet): namespaced_rule_id = _namespaced_rule_id_from_spec(spec) try: - yield self.store.delete_push_rule( - user_id, namespaced_rule_id - ) + yield self.store.delete_push_rule(user_id, namespaced_rule_id) self.notify_user(user_id) defer.returnValue((200, {})) except StoreError as e: @@ -149,10 +150,10 @@ class PushRuleRestServlet(RestServlet): PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR ) - if path[0] == '': + if path[0] == "": defer.returnValue((200, rules)) - elif path[0] == 'global': - result = _filter_ruleset_with_path(rules['global'], path[1:]) + elif path[0] == "global": + result = _filter_ruleset_with_path(rules["global"], path[1:]) defer.returnValue((200, result)) else: raise UnrecognizedRequestError() @@ -162,12 +163,10 @@ class PushRuleRestServlet(RestServlet): def notify_user(self, user_id): stream_id, _ = self.store.get_push_rules_stream_token() - self.notifier.on_new_event( - "push_rules_key", stream_id, users=[user_id] - ) + self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id]) def set_rule_attr(self, user_id, spec, val): - if spec['attr'] == 'enabled': + if spec["attr"] == "enabled": if isinstance(val, dict) and "enabled" in val: val = val["enabled"] if not isinstance(val, bool): @@ -176,14 +175,12 @@ class PushRuleRestServlet(RestServlet): # bools directly, so let's not break them. raise SynapseError(400, "Value for 'enabled' must be boolean") namespaced_rule_id = _namespaced_rule_id_from_spec(spec) - return self.store.set_push_rule_enabled( - user_id, namespaced_rule_id, val - ) - elif spec['attr'] == 'actions': - actions = val.get('actions') + return self.store.set_push_rule_enabled(user_id, namespaced_rule_id, val) + elif spec["attr"] == "actions": + actions = val.get("actions") _check_actions(actions) namespaced_rule_id = _namespaced_rule_id_from_spec(spec) - rule_id = spec['rule_id'] + rule_id = spec["rule_id"] is_default_rule = rule_id.startswith(".") if is_default_rule: if namespaced_rule_id not in BASE_RULE_IDS: @@ -210,12 +207,12 @@ def _rule_spec_from_path(path): """ if len(path) < 2: raise UnrecognizedRequestError() - if path[0] != 'pushrules': + if path[0] != "pushrules": raise UnrecognizedRequestError() scope = path[1] path = path[2:] - if scope != 'global': + if scope != "global": raise UnrecognizedRequestError() if len(path) == 0: @@ -229,56 +226,40 @@ def _rule_spec_from_path(path): rule_id = path[0] - spec = { - 'scope': scope, - 'template': template, - 'rule_id': rule_id - } + spec = {"scope": scope, "template": template, "rule_id": rule_id} path = path[1:] if len(path) > 0 and len(path[0]) > 0: - spec['attr'] = path[0] + spec["attr"] = path[0] return spec def _rule_tuple_from_request_object(rule_template, rule_id, req_obj): - if rule_template in ['override', 'underride']: - if 'conditions' not in req_obj: + if rule_template in ["override", "underride"]: + if "conditions" not in req_obj: raise InvalidRuleException("Missing 'conditions'") - conditions = req_obj['conditions'] + conditions = req_obj["conditions"] for c in conditions: - if 'kind' not in c: + if "kind" not in c: raise InvalidRuleException("Condition without 'kind'") - elif rule_template == 'room': - conditions = [{ - 'kind': 'event_match', - 'key': 'room_id', - 'pattern': rule_id - }] - elif rule_template == 'sender': - conditions = [{ - 'kind': 'event_match', - 'key': 'user_id', - 'pattern': rule_id - }] - elif rule_template == 'content': - if 'pattern' not in req_obj: + elif rule_template == "room": + conditions = [{"kind": "event_match", "key": "room_id", "pattern": rule_id}] + elif rule_template == "sender": + conditions = [{"kind": "event_match", "key": "user_id", "pattern": rule_id}] + elif rule_template == "content": + if "pattern" not in req_obj: raise InvalidRuleException("Content rule missing 'pattern'") - pat = req_obj['pattern'] + pat = req_obj["pattern"] - conditions = [{ - 'kind': 'event_match', - 'key': 'content.body', - 'pattern': pat - }] + conditions = [{"kind": "event_match", "key": "content.body", "pattern": pat}] else: raise InvalidRuleException("Unknown rule template: %s" % (rule_template,)) - if 'actions' not in req_obj: + if "actions" not in req_obj: raise InvalidRuleException("No actions found") - actions = req_obj['actions'] + actions = req_obj["actions"] _check_actions(actions) @@ -290,9 +271,9 @@ def _check_actions(actions): raise InvalidRuleException("No actions found") for a in actions: - if a in ['notify', 'dont_notify', 'coalesce']: + if a in ["notify", "dont_notify", "coalesce"]: pass - elif isinstance(a, dict) and 'set_tweak' in a: + elif isinstance(a, dict) and "set_tweak" in a: pass else: raise InvalidRuleException("Unrecognised action") @@ -304,7 +285,7 @@ def _filter_ruleset_with_path(ruleset, path): PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR ) - if path[0] == '': + if path[0] == "": return ruleset template_kind = path[0] if template_kind not in ruleset: @@ -314,13 +295,13 @@ def _filter_ruleset_with_path(ruleset, path): raise UnrecognizedRequestError( PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR ) - if path[0] == '': + if path[0] == "": return ruleset[template_kind] rule_id = path[0] the_rule = None for r in ruleset[template_kind]: - if r['rule_id'] == rule_id: + if r["rule_id"] == rule_id: the_rule = r if the_rule is None: raise NotFoundError @@ -339,19 +320,19 @@ def _filter_ruleset_with_path(ruleset, path): def _priority_class_from_spec(spec): - if spec['template'] not in PRIORITY_CLASS_MAP.keys(): - raise InvalidRuleException("Unknown template: %s" % (spec['template'])) - pc = PRIORITY_CLASS_MAP[spec['template']] + if spec["template"] not in PRIORITY_CLASS_MAP.keys(): + raise InvalidRuleException("Unknown template: %s" % (spec["template"])) + pc = PRIORITY_CLASS_MAP[spec["template"]] return pc def _namespaced_rule_id_from_spec(spec): - return _namespaced_rule_id(spec, spec['rule_id']) + return _namespaced_rule_id(spec, spec["rule_id"]) def _namespaced_rule_id(spec, rule_id): - return "global/%s/%s" % (spec['template'], rule_id) + return "global/%s/%s" % (spec["template"], rule_id) class InvalidRuleException(Exception): diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 15d860db37..e9246018df 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -44,9 +44,7 @@ class PushersRestServlet(RestServlet): requester = yield self.auth.get_user_by_req(request) user = requester.user - pushers = yield self.hs.get_datastore().get_pushers_by_user_id( - user.to_string() - ) + pushers = yield self.hs.get_datastore().get_pushers_by_user_id(user.to_string()) allowed_keys = [ "app_display_name", @@ -87,50 +85,61 @@ class PushersSetRestServlet(RestServlet): content = parse_json_object_from_request(request) - if ('pushkey' in content and 'app_id' in content - and 'kind' in content and - content['kind'] is None): + if ( + "pushkey" in content + and "app_id" in content + and "kind" in content + and content["kind"] is None + ): yield self.pusher_pool.remove_pusher( - content['app_id'], content['pushkey'], user_id=user.to_string() + content["app_id"], content["pushkey"], user_id=user.to_string() ) defer.returnValue((200, {})) assert_params_in_dict( content, - ['kind', 'app_id', 'app_display_name', - 'device_display_name', 'pushkey', 'lang', 'data'] + [ + "kind", + "app_id", + "app_display_name", + "device_display_name", + "pushkey", + "lang", + "data", + ], ) - logger.debug("set pushkey %s to kind %s", content['pushkey'], content['kind']) + logger.debug("set pushkey %s to kind %s", content["pushkey"], content["kind"]) logger.debug("Got pushers request with body: %r", content) append = False - if 'append' in content: - append = content['append'] + if "append" in content: + append = content["append"] if not append: yield self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user( - app_id=content['app_id'], - pushkey=content['pushkey'], - not_user_id=user.to_string() + app_id=content["app_id"], + pushkey=content["pushkey"], + not_user_id=user.to_string(), ) try: yield self.pusher_pool.add_pusher( user_id=user.to_string(), access_token=requester.access_token_id, - kind=content['kind'], - app_id=content['app_id'], - app_display_name=content['app_display_name'], - device_display_name=content['device_display_name'], - pushkey=content['pushkey'], - lang=content['lang'], - data=content['data'], - profile_tag=content.get('profile_tag', ""), + kind=content["kind"], + app_id=content["app_id"], + app_display_name=content["app_display_name"], + device_display_name=content["device_display_name"], + pushkey=content["pushkey"], + lang=content["lang"], + data=content["data"], + profile_tag=content.get("profile_tag", ""), ) except PusherConfigException as pce: - raise SynapseError(400, "Config Error: " + str(pce), - errcode=Codes.MISSING_PARAM) + raise SynapseError( + 400, "Config Error: " + str(pce), errcode=Codes.MISSING_PARAM + ) self.notifier.on_new_replication_data() @@ -144,6 +153,7 @@ class PushersRemoveRestServlet(RestServlet): """ To allow pusher to be delete by clicking a link (ie. GET request) """ + PATTERNS = client_patterns("/pushers/remove$", v1=True) SUCCESS_HTML = b"<html><body>You have been unsubscribed</body><html>" @@ -164,9 +174,7 @@ class PushersRemoveRestServlet(RestServlet): try: yield self.pusher_pool.remove_pusher( - app_id=app_id, - pushkey=pushkey, - user_id=user.to_string(), + app_id=app_id, pushkey=pushkey, user_id=user.to_string() ) except StoreError as se: if se.code != 404: @@ -177,9 +185,9 @@ class PushersRemoveRestServlet(RestServlet): request.setResponseCode(200) request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % ( - len(PushersRemoveRestServlet.SUCCESS_HTML), - )) + request.setHeader( + b"Content-Length", b"%d" % (len(PushersRemoveRestServlet.SUCCESS_HTML),) + ) request.write(PushersRemoveRestServlet.SUCCESS_HTML) finish_request(request) defer.returnValue(None) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index e8f672c4ba..a028337125 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -61,18 +61,16 @@ class RoomCreateRestServlet(TransactionRestServlet): PATTERNS = "/createRoom" register_txn_path(self, PATTERNS, http_server) # define CORS for all of /rooms in RoomCreateRestServlet for simplicity - http_server.register_paths("OPTIONS", - client_patterns("/rooms(?:/.*)?$", v1=True), - self.on_OPTIONS) + http_server.register_paths( + "OPTIONS", client_patterns("/rooms(?:/.*)?$", v1=True), self.on_OPTIONS + ) # define CORS for /createRoom[/txnid] - http_server.register_paths("OPTIONS", - client_patterns("/createRoom(?:/.*)?$", v1=True), - self.on_OPTIONS) + http_server.register_paths( + "OPTIONS", client_patterns("/createRoom(?:/.*)?$", v1=True), self.on_OPTIONS + ) def on_PUT(self, request, txn_id): - return self.txns.fetch_or_execute_request( - request, self.on_POST, request - ) + return self.txns.fetch_or_execute_request(request, self.on_POST, request) @defer.inlineCallbacks def on_POST(self, request): @@ -107,21 +105,23 @@ class RoomStateEventRestServlet(TransactionRestServlet): no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$" # /room/$roomid/state/$eventtype/$statekey - state_key = ("/rooms/(?P<room_id>[^/]*)/state/" - "(?P<event_type>[^/]*)/(?P<state_key>[^/]*)$") - - http_server.register_paths("GET", - client_patterns(state_key, v1=True), - self.on_GET) - http_server.register_paths("PUT", - client_patterns(state_key, v1=True), - self.on_PUT) - http_server.register_paths("GET", - client_patterns(no_state_key, v1=True), - self.on_GET_no_state_key) - http_server.register_paths("PUT", - client_patterns(no_state_key, v1=True), - self.on_PUT_no_state_key) + state_key = ( + "/rooms/(?P<room_id>[^/]*)/state/" + "(?P<event_type>[^/]*)/(?P<state_key>[^/]*)$" + ) + + http_server.register_paths( + "GET", client_patterns(state_key, v1=True), self.on_GET + ) + http_server.register_paths( + "PUT", client_patterns(state_key, v1=True), self.on_PUT + ) + http_server.register_paths( + "GET", client_patterns(no_state_key, v1=True), self.on_GET_no_state_key + ) + http_server.register_paths( + "PUT", client_patterns(no_state_key, v1=True), self.on_PUT_no_state_key + ) def on_GET_no_state_key(self, request, room_id, event_type): return self.on_GET(request, room_id, event_type, "") @@ -132,8 +132,9 @@ class RoomStateEventRestServlet(TransactionRestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id, event_type, state_key): requester = yield self.auth.get_user_by_req(request, allow_guest=True) - format = parse_string(request, "format", default="content", - allowed_values=["content", "event"]) + format = parse_string( + request, "format", default="content", allowed_values=["content", "event"] + ) msg_handler = self.message_handler data = yield msg_handler.get_room_data( @@ -145,9 +146,7 @@ class RoomStateEventRestServlet(TransactionRestServlet): ) if not data: - raise SynapseError( - 404, "Event not found.", errcode=Codes.NOT_FOUND - ) + raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) if format == "event": event = format_event_for_client_v2(data.get_dict()) @@ -182,9 +181,7 @@ class RoomStateEventRestServlet(TransactionRestServlet): ) else: event = yield self.event_creation_handler.create_and_send_nonmember_event( - requester, - event_dict, - txn_id=txn_id, + requester, event_dict, txn_id=txn_id ) ret = {} @@ -195,7 +192,6 @@ class RoomStateEventRestServlet(TransactionRestServlet): # TODO: Needs unit testing for generic events + feedback class RoomSendEventRestServlet(TransactionRestServlet): - def __init__(self, hs): super(RoomSendEventRestServlet, self).__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() @@ -203,7 +199,7 @@ class RoomSendEventRestServlet(TransactionRestServlet): def register(self, http_server): # /rooms/$roomid/send/$event_type[/$txn_id] - PATTERNS = ("/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)") + PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)" register_txn_path(self, PATTERNS, http_server, with_get=True) @defer.inlineCallbacks @@ -218,13 +214,11 @@ class RoomSendEventRestServlet(TransactionRestServlet): "sender": requester.user.to_string(), } - if b'ts' in request.args and requester.app_service: - event_dict['origin_server_ts'] = parse_integer(request, "ts", 0) + if b"ts" in request.args and requester.app_service: + event_dict["origin_server_ts"] = parse_integer(request, "ts", 0) event = yield self.event_creation_handler.create_and_send_nonmember_event( - requester, - event_dict, - txn_id=txn_id, + requester, event_dict, txn_id=txn_id ) defer.returnValue((200, {"event_id": event.event_id})) @@ -247,15 +241,12 @@ class JoinRoomAliasServlet(TransactionRestServlet): def register(self, http_server): # /join/$room_identifier[/$txn_id] - PATTERNS = ("/join/(?P<room_identifier>[^/]*)") + PATTERNS = "/join/(?P<room_identifier>[^/]*)" register_txn_path(self, PATTERNS, http_server) @defer.inlineCallbacks def on_POST(self, request, room_identifier, txn_id=None): - requester = yield self.auth.get_user_by_req( - request, - allow_guest=True, - ) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) try: content = parse_json_object_from_request(request) @@ -268,7 +259,7 @@ class JoinRoomAliasServlet(TransactionRestServlet): room_id = room_identifier try: remote_room_hosts = [ - x.decode('ascii') for x in request.args[b"server_name"] + x.decode("ascii") for x in request.args[b"server_name"] ] except Exception: remote_room_hosts = None @@ -278,9 +269,9 @@ class JoinRoomAliasServlet(TransactionRestServlet): room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias) room_id = room_id.to_string() else: - raise SynapseError(400, "%s was not legal room ID or room alias" % ( - room_identifier, - )) + raise SynapseError( + 400, "%s was not legal room ID or room alias" % (room_identifier,) + ) yield self.room_member_handler.update_membership( requester=requester, @@ -339,14 +330,11 @@ class PublicRoomListRestServlet(TransactionRestServlet): handler = self.hs.get_room_list_handler() if server: data = yield handler.get_remote_public_room_list( - server, - limit=limit, - since_token=since_token, + server, limit=limit, since_token=since_token ) else: data = yield handler.get_local_public_room_list( - limit=limit, - since_token=since_token, + limit=limit, since_token=since_token ) defer.returnValue((200, data)) @@ -439,16 +427,13 @@ class RoomMemberListRestServlet(RestServlet): chunk = [] for event in events: - if ( - (membership and event['content'].get("membership") != membership) or - (not_membership and event['content'].get("membership") == not_membership) + if (membership and event["content"].get("membership") != membership) or ( + not_membership and event["content"].get("membership") == not_membership ): continue chunk.append(event) - defer.returnValue((200, { - "chunk": chunk - })) + defer.returnValue((200, {"chunk": chunk})) # deprecated in favour of /members?membership=join? @@ -466,12 +451,10 @@ class JoinedRoomMemberListRestServlet(RestServlet): requester = yield self.auth.get_user_by_req(request) users_with_profile = yield self.message_handler.get_joined_members( - requester, room_id, + requester, room_id ) - defer.returnValue((200, { - "joined": users_with_profile, - })) + defer.returnValue((200, {"joined": users_with_profile})) # TODO: Needs better unit testing @@ -486,9 +469,7 @@ class RoomMessageListRestServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): requester = yield self.auth.get_user_by_req(request, allow_guest=True) - pagination_config = PaginationConfig.from_request( - request, default_limit=10, - ) + pagination_config = PaginationConfig.from_request(request, default_limit=10) as_client_event = b"raw" not in request.args filter_bytes = parse_string(request, b"filter", encoding=None) if filter_bytes: @@ -544,9 +525,7 @@ class RoomInitialSyncRestServlet(RestServlet): requester = yield self.auth.get_user_by_req(request, allow_guest=True) pagination_config = PaginationConfig.from_request(request) content = yield self.initial_sync_handler.room_initial_sync( - room_id=room_id, - requester=requester, - pagin_config=pagination_config, + room_id=room_id, requester=requester, pagin_config=pagination_config ) defer.returnValue((200, content)) @@ -603,30 +582,24 @@ class RoomEventContextServlet(RestServlet): event_filter = None results = yield self.room_context_handler.get_event_context( - requester.user, - room_id, - event_id, - limit, - event_filter, + requester.user, room_id, event_id, limit, event_filter ) if not results: - raise SynapseError( - 404, "Event not found.", errcode=Codes.NOT_FOUND - ) + raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) time_now = self.clock.time_msec() results["events_before"] = yield self._event_serializer.serialize_events( - results["events_before"], time_now, + results["events_before"], time_now ) results["event"] = yield self._event_serializer.serialize_event( - results["event"], time_now, + results["event"], time_now ) results["events_after"] = yield self._event_serializer.serialize_events( - results["events_after"], time_now, + results["events_after"], time_now ) results["state"] = yield self._event_serializer.serialize_events( - results["state"], time_now, + results["state"], time_now ) defer.returnValue((200, results)) @@ -639,20 +612,14 @@ class RoomForgetRestServlet(TransactionRestServlet): self.auth = hs.get_auth() def register(self, http_server): - PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget") + PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget" register_txn_path(self, PATTERNS, http_server) @defer.inlineCallbacks def on_POST(self, request, room_id, txn_id=None): - requester = yield self.auth.get_user_by_req( - request, - allow_guest=False, - ) + requester = yield self.auth.get_user_by_req(request, allow_guest=False) - yield self.room_member_handler.forget( - user=requester.user, - room_id=room_id, - ) + yield self.room_member_handler.forget(user=requester.user, room_id=room_id) defer.returnValue((200, {})) @@ -664,7 +631,6 @@ class RoomForgetRestServlet(TransactionRestServlet): # TODO: Needs unit testing class RoomMembershipRestServlet(TransactionRestServlet): - def __init__(self, hs): super(RoomMembershipRestServlet, self).__init__(hs) self.room_member_handler = hs.get_room_member_handler() @@ -672,20 +638,19 @@ class RoomMembershipRestServlet(TransactionRestServlet): def register(self, http_server): # /rooms/$roomid/[invite|join|leave] - PATTERNS = ("/rooms/(?P<room_id>[^/]*)/" - "(?P<membership_action>join|invite|leave|ban|unban|kick)") + PATTERNS = ( + "/rooms/(?P<room_id>[^/]*)/" + "(?P<membership_action>join|invite|leave|ban|unban|kick)" + ) register_txn_path(self, PATTERNS, http_server) @defer.inlineCallbacks def on_POST(self, request, room_id, membership_action, txn_id=None): - requester = yield self.auth.get_user_by_req( - request, - allow_guest=True, - ) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) if requester.is_guest and membership_action not in { Membership.JOIN, - Membership.LEAVE + Membership.LEAVE, }: raise AuthError(403, "Guest access not allowed") @@ -704,7 +669,7 @@ class RoomMembershipRestServlet(TransactionRestServlet): content["address"], content["id_server"], requester, - txn_id + txn_id, ) defer.returnValue((200, {})) return @@ -715,8 +680,8 @@ class RoomMembershipRestServlet(TransactionRestServlet): target = UserID.from_string(content["user_id"]) event_content = None - if 'reason' in content and membership_action in ['kick', 'ban']: - event_content = {'reason': content['reason']} + if "reason" in content and membership_action in ["kick", "ban"]: + event_content = {"reason": content["reason"]} yield self.room_member_handler.update_membership( requester=requester, @@ -755,7 +720,7 @@ class RoomRedactEventRestServlet(TransactionRestServlet): self.auth = hs.get_auth() def register(self, http_server): - PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)") + PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)" register_txn_path(self, PATTERNS, http_server) @defer.inlineCallbacks @@ -817,9 +782,7 @@ class RoomTypingRestServlet(RestServlet): ) else: yield self.typing_handler.stopped_typing( - target_user=target_user, - auth_user=requester.user, - room_id=room_id, + target_user=target_user, auth_user=requester.user, room_id=room_id ) defer.returnValue((200, {})) @@ -841,9 +804,7 @@ class SearchRestServlet(RestServlet): batch = parse_string(request, "next_batch") results = yield self.handlers.search_handler.search( - requester.user, - content, - batch, + requester.user, content, batch ) defer.returnValue((200, results)) @@ -879,20 +840,18 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False): with_get: True to also register respective GET paths for the PUTs. """ http_server.register_paths( - "POST", - client_patterns(regex_string + "$", v1=True), - servlet.on_POST + "POST", client_patterns(regex_string + "$", v1=True), servlet.on_POST ) http_server.register_paths( "PUT", client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True), - servlet.on_PUT + servlet.on_PUT, ) if with_get: http_server.register_paths( "GET", client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True), - servlet.on_GET + servlet.on_GET, ) diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index 6381049210..41b3171ac8 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -34,8 +34,7 @@ class VoipRestServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request): requester = yield self.auth.get_user_by_req( - request, - self.hs.config.turn_allow_guests + request, self.hs.config.turn_allow_guests ) turnUris = self.hs.config.turn_uris @@ -49,9 +48,7 @@ class VoipRestServlet(RestServlet): username = "%d:%s" % (expiry, requester.user.to_string()) mac = hmac.new( - turnSecret.encode(), - msg=username.encode(), - digestmod=hashlib.sha1 + turnSecret.encode(), msg=username.encode(), digestmod=hashlib.sha1 ) # We need to use standard padded base64 encoding here # encode_base64 because we need to add the standard padding to get the @@ -65,12 +62,17 @@ class VoipRestServlet(RestServlet): else: defer.returnValue((200, {})) - defer.returnValue((200, { - 'username': username, - 'password': password, - 'ttl': userLifetime / 1000, - 'uris': turnUris, - })) + defer.returnValue( + ( + 200, + { + "username": username, + "password": password, + "ttl": userLifetime / 1000, + "uris": turnUris, + }, + ) + ) def on_OPTIONS(self, request): return (200, {}) diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py index 5236d5d566..e3d59ac3ac 100644 --- a/synapse/rest/client/v2_alpha/_base.py +++ b/synapse/rest/client/v2_alpha/_base.py @@ -52,11 +52,11 @@ def client_patterns(path_regex, releases=(0,), unstable=True, v1=False): def set_timeline_upper_limit(filter_json, filter_timeline_limit): if filter_timeline_limit < 0: return # no upper limits - timeline = filter_json.get('room', {}).get('timeline', {}) - if 'limit' in timeline: - filter_json['room']['timeline']["limit"] = min( - filter_json['room']['timeline']['limit'], - filter_timeline_limit) + timeline = filter_json.get("room", {}).get("timeline", {}) + if "limit" in timeline: + filter_json["room"]["timeline"]["limit"] = min( + filter_json["room"]["timeline"]["limit"], filter_timeline_limit + ) def interactive_auth_handler(orig): @@ -74,10 +74,12 @@ def interactive_auth_handler(orig): # ... yield self.auth_handler.check_auth """ + def wrapped(*args, **kwargs): res = defer.maybeDeferred(orig, *args, **kwargs) res.addErrback(_catch_incomplete_interactive_auth) return res + return wrapped diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index ab75f6c2b2..f143d8b85c 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -52,6 +52,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): if self.config.email_password_reset_behaviour == "local": from synapse.push.mailer import Mailer, load_jinja2_templates + templates = load_jinja2_templates( config=hs.config, template_html_name=hs.config.email_password_reset_template_html, @@ -72,14 +73,12 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): "User password resets have been disabled due to lack of email config" ) raise SynapseError( - 400, "Email-based password resets have been disabled on this server", + 400, "Email-based password resets have been disabled on this server" ) body = parse_json_object_from_request(request) - assert_params_in_dict(body, [ - 'client_secret', 'email', 'send_attempt' - ]) + assert_params_in_dict(body, ["client_secret", "email", "send_attempt"]) # Extract params from body client_secret = body["client_secret"] @@ -95,24 +94,24 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): ) existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( - 'email', email, + "email", email ) if existingUid is None: raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) if self.config.email_password_reset_behaviour == "remote": - if 'id_server' not in body: + if "id_server" not in body: raise SynapseError(400, "Missing 'id_server' param in body") # Have the identity server handle the password reset flow ret = yield self.identity_handler.requestEmailToken( - body["id_server"], email, client_secret, send_attempt, next_link, + body["id_server"], email, client_secret, send_attempt, next_link ) else: # Send password reset emails from Synapse sid = yield self.send_password_reset( - email, client_secret, send_attempt, next_link, + email, client_secret, send_attempt, next_link ) # Wrap the session id in a JSON object @@ -121,13 +120,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): defer.returnValue((200, ret)) @defer.inlineCallbacks - def send_password_reset( - self, - email, - client_secret, - send_attempt, - next_link=None, - ): + def send_password_reset(self, email, client_secret, send_attempt, next_link=None): """Send a password reset email Args: @@ -144,14 +137,14 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): # Check that this email/client_secret/send_attempt combo is new or # greater than what we've seen previously session = yield self.datastore.get_threepid_validation_session( - "email", client_secret, address=email, validated=False, + "email", client_secret, address=email, validated=False ) # Check to see if a session already exists and that it is not yet # marked as validated if session and session.get("validated_at") is None: - session_id = session['session_id'] - last_send_attempt = session['last_send_attempt'] + session_id = session["session_id"] + last_send_attempt = session["last_send_attempt"] # Check that the send_attempt is higher than previous attempts if send_attempt <= last_send_attempt: @@ -169,22 +162,27 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): # and session_id try: yield self.mailer.send_password_reset_mail( - email, token, client_secret, session_id, + email, token, client_secret, session_id ) except Exception: - logger.exception( - "Error sending a password reset email to %s", email, - ) + logger.exception("Error sending a password reset email to %s", email) raise SynapseError( 500, "An error was encountered when sending the password reset email" ) - token_expires = (self.hs.clock.time_msec() + - self.config.email_validation_token_lifetime) + token_expires = ( + self.hs.clock.time_msec() + self.config.email_validation_token_lifetime + ) yield self.datastore.start_or_continue_validation_session( - "email", email, session_id, client_secret, send_attempt, - next_link, token, token_expires, + "email", + email, + session_id, + client_secret, + send_attempt, + next_link, + token, + token_expires, ) defer.returnValue(session_id) @@ -203,12 +201,12 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet): def on_POST(self, request): body = parse_json_object_from_request(request) - assert_params_in_dict(body, [ - 'id_server', 'client_secret', - 'country', 'phone_number', 'send_attempt', - ]) + assert_params_in_dict( + body, + ["id_server", "client_secret", "country", "phone_number", "send_attempt"], + ) - msisdn = phone_number_to_msisdn(body['country'], body['phone_number']) + msisdn = phone_number_to_msisdn(body["country"], body["phone_number"]) if not check_3pid_allowed(self.hs, "msisdn", msisdn): raise SynapseError( @@ -217,9 +215,7 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - existingUid = yield self.datastore.get_user_id_by_threepid( - 'msisdn', msisdn - ) + existingUid = yield self.datastore.get_user_id_by_threepid("msisdn", msisdn) if existingUid is None: raise SynapseError(400, "MSISDN not found", Codes.THREEPID_NOT_FOUND) @@ -230,10 +226,9 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet): class PasswordResetSubmitTokenServlet(RestServlet): """Handles 3PID validation token submission""" + PATTERNS = client_patterns( - "/password_reset/(?P<medium>[^/]*)/submit_token/*$", - releases=(), - unstable=True, + "/password_reset/(?P<medium>[^/]*)/submit_token/*$", releases=(), unstable=True ) def __init__(self, hs): @@ -252,8 +247,7 @@ class PasswordResetSubmitTokenServlet(RestServlet): def on_GET(self, request, medium): if medium != "email": raise SynapseError( - 400, - "This medium is currently not supported for password resets", + 400, "This medium is currently not supported for password resets" ) if self.config.email_password_reset_behaviour == "off": if self.config.password_resets_were_disabled_due_to_email_config: @@ -261,7 +255,7 @@ class PasswordResetSubmitTokenServlet(RestServlet): "User password resets have been disabled due to lack of email config" ) raise SynapseError( - 400, "Email-based password resets have been disabled on this server", + 400, "Email-based password resets have been disabled on this server" ) sid = parse_string(request, "sid") @@ -272,10 +266,7 @@ class PasswordResetSubmitTokenServlet(RestServlet): try: # Mark the session as valid next_link = yield self.datastore.validate_threepid_session( - sid, - client_secret, - token, - self.clock.time_msec(), + sid, client_secret, token, self.clock.time_msec() ) # Perform a 302 redirect if next_link is set @@ -298,13 +289,11 @@ class PasswordResetSubmitTokenServlet(RestServlet): html = self.load_jinja2_template( self.config.email_template_dir, self.config.email_password_reset_failure_template, - template_vars={ - "failure_reason": e.msg, - } + template_vars={"failure_reason": e.msg}, ) request.setResponseCode(e.code) - request.write(html.encode('utf-8')) + request.write(html.encode("utf-8")) finish_request(request) defer.returnValue(None) @@ -330,20 +319,14 @@ class PasswordResetSubmitTokenServlet(RestServlet): def on_POST(self, request, medium): if medium != "email": raise SynapseError( - 400, - "This medium is currently not supported for password resets", + 400, "This medium is currently not supported for password resets" ) body = parse_json_object_from_request(request) - assert_params_in_dict(body, [ - 'sid', 'client_secret', 'token', - ]) + assert_params_in_dict(body, ["sid", "client_secret", "token"]) valid, _ = yield self.datastore.validate_threepid_validation_token( - body['sid'], - body['client_secret'], - body['token'], - self.clock.time_msec(), + body["sid"], body["client_secret"], body["token"], self.clock.time_msec() ) response_code = 200 if valid else 400 @@ -379,29 +362,30 @@ class PasswordRestServlet(RestServlet): if self.auth.has_access_token(request): requester = yield self.auth.get_user_by_req(request) params = yield self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request), + requester, body, self.hs.get_ip_from_request(request) ) user_id = requester.user.to_string() else: requester = None result, params, _ = yield self.auth_handler.check_auth( [[LoginType.EMAIL_IDENTITY], [LoginType.MSISDN]], - body, self.hs.get_ip_from_request(request), + body, + self.hs.get_ip_from_request(request), password_servlet=True, ) if LoginType.EMAIL_IDENTITY in result: threepid = result[LoginType.EMAIL_IDENTITY] - if 'medium' not in threepid or 'address' not in threepid: + if "medium" not in threepid or "address" not in threepid: raise SynapseError(500, "Malformed threepid") - if threepid['medium'] == 'email': + if threepid["medium"] == "email": # For emails, transform the address to lowercase. # We store all email addreses as lowercase in the DB. # (See add_threepid in synapse/handlers/auth.py) - threepid['address'] = threepid['address'].lower() + threepid["address"] = threepid["address"].lower() # if using email, we must know about the email they're authing with! threepid_user_id = yield self.datastore.get_user_id_by_threepid( - threepid['medium'], threepid['address'] + threepid["medium"], threepid["address"] ) if not threepid_user_id: raise SynapseError(404, "Email address not found", Codes.NOT_FOUND) @@ -411,11 +395,9 @@ class PasswordRestServlet(RestServlet): raise SynapseError(500, "", Codes.UNKNOWN) assert_params_in_dict(params, ["new_password"]) - new_password = params['new_password'] + new_password = params["new_password"] - yield self._set_password_handler.set_password( - user_id, new_password, requester - ) + yield self._set_password_handler.set_password(user_id, new_password, requester) defer.returnValue((200, {})) @@ -450,25 +432,22 @@ class DeactivateAccountRestServlet(RestServlet): # allow ASes to dectivate their own users if requester.app_service: yield self._deactivate_account_handler.deactivate_account( - requester.user.to_string(), erase, + requester.user.to_string(), erase ) defer.returnValue((200, {})) yield self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request), + requester, body, self.hs.get_ip_from_request(request) ) result = yield self._deactivate_account_handler.deactivate_account( - requester.user.to_string(), erase, - id_server=body.get("id_server"), + requester.user.to_string(), erase, id_server=body.get("id_server") ) if result: id_server_unbind_result = "success" else: id_server_unbind_result = "no-support" - defer.returnValue((200, { - "id_server_unbind_result": id_server_unbind_result, - })) + defer.returnValue((200, {"id_server_unbind_result": id_server_unbind_result})) class EmailThreepidRequestTokenRestServlet(RestServlet): @@ -484,11 +463,10 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): def on_POST(self, request): body = parse_json_object_from_request(request) assert_params_in_dict( - body, - ['id_server', 'client_secret', 'email', 'send_attempt'], + body, ["id_server", "client_secret", "email", "send_attempt"] ) - if not check_3pid_allowed(self.hs, "email", body['email']): + if not check_3pid_allowed(self.hs, "email", body["email"]): raise SynapseError( 403, "Your email domain is not authorized on this server", @@ -496,7 +474,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): ) existingUid = yield self.datastore.get_user_id_by_threepid( - 'email', body['email'] + "email", body["email"] ) if existingUid is not None: @@ -518,12 +496,12 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request): body = parse_json_object_from_request(request) - assert_params_in_dict(body, [ - 'id_server', 'client_secret', - 'country', 'phone_number', 'send_attempt', - ]) + assert_params_in_dict( + body, + ["id_server", "client_secret", "country", "phone_number", "send_attempt"], + ) - msisdn = phone_number_to_msisdn(body['country'], body['phone_number']) + msisdn = phone_number_to_msisdn(body["country"], body["phone_number"]) if not check_3pid_allowed(self.hs, "msisdn", msisdn): raise SynapseError( @@ -532,9 +510,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) - existingUid = yield self.datastore.get_user_id_by_threepid( - 'msisdn', msisdn - ) + existingUid = yield self.datastore.get_user_id_by_threepid("msisdn", msisdn) if existingUid is not None: raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) @@ -558,18 +534,16 @@ class ThreepidRestServlet(RestServlet): def on_GET(self, request): requester = yield self.auth.get_user_by_req(request) - threepids = yield self.datastore.user_get_threepids( - requester.user.to_string() - ) + threepids = yield self.datastore.user_get_threepids(requester.user.to_string()) - defer.returnValue((200, {'threepids': threepids})) + defer.returnValue((200, {"threepids": threepids})) @defer.inlineCallbacks def on_POST(self, request): body = parse_json_object_from_request(request) - threePidCreds = body.get('threePidCreds') - threePidCreds = body.get('three_pid_creds', threePidCreds) + threePidCreds = body.get("threePidCreds") + threePidCreds = body.get("three_pid_creds", threePidCreds) if threePidCreds is None: raise SynapseError(400, "Missing param", Codes.MISSING_PARAM) @@ -579,30 +553,20 @@ class ThreepidRestServlet(RestServlet): threepid = yield self.identity_handler.threepid_from_creds(threePidCreds) if not threepid: - raise SynapseError( - 400, "Failed to auth 3pid", Codes.THREEPID_AUTH_FAILED - ) + raise SynapseError(400, "Failed to auth 3pid", Codes.THREEPID_AUTH_FAILED) - for reqd in ['medium', 'address', 'validated_at']: + for reqd in ["medium", "address", "validated_at"]: if reqd not in threepid: logger.warn("Couldn't add 3pid: invalid response from ID server") raise SynapseError(500, "Invalid response from ID Server") yield self.auth_handler.add_threepid( - user_id, - threepid['medium'], - threepid['address'], - threepid['validated_at'], + user_id, threepid["medium"], threepid["address"], threepid["validated_at"] ) - if 'bind' in body and body['bind']: - logger.debug( - "Binding threepid %s to %s", - threepid, user_id - ) - yield self.identity_handler.bind_threepid( - threePidCreds, user_id - ) + if "bind" in body and body["bind"]: + logger.debug("Binding threepid %s to %s", threepid, user_id) + yield self.identity_handler.bind_threepid(threePidCreds, user_id) defer.returnValue((200, {})) @@ -618,14 +582,14 @@ class ThreepidDeleteRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request): body = parse_json_object_from_request(request) - assert_params_in_dict(body, ['medium', 'address']) + assert_params_in_dict(body, ["medium", "address"]) requester = yield self.auth.get_user_by_req(request) user_id = requester.user.to_string() try: ret = yield self.auth_handler.delete_threepid( - user_id, body['medium'], body['address'], body.get("id_server"), + user_id, body["medium"], body["address"], body.get("id_server") ) except Exception: # NB. This endpoint should succeed if there is nothing to @@ -639,9 +603,7 @@ class ThreepidDeleteRestServlet(RestServlet): else: id_server_unbind_result = "no-support" - defer.returnValue((200, { - "id_server_unbind_result": id_server_unbind_result, - })) + defer.returnValue((200, {"id_server_unbind_result": id_server_unbind_result})) class WhoamiRestServlet(RestServlet): @@ -655,7 +617,7 @@ class WhoamiRestServlet(RestServlet): def on_GET(self, request): requester = yield self.auth.get_user_by_req(request) - defer.returnValue((200, {'user_id': requester.user.to_string()})) + defer.returnValue((200, {"user_id": requester.user.to_string()})) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py index 574a6298ce..f155c26259 100644 --- a/synapse/rest/client/v2_alpha/account_data.py +++ b/synapse/rest/client/v2_alpha/account_data.py @@ -30,6 +30,7 @@ class AccountDataServlet(RestServlet): PUT /user/{user_id}/account_data/{account_dataType} HTTP/1.1 GET /user/{user_id}/account_data/{account_dataType} HTTP/1.1 """ + PATTERNS = client_patterns( "/user/(?P<user_id>[^/]*)/account_data/(?P<account_data_type>[^/]*)" ) @@ -52,9 +53,7 @@ class AccountDataServlet(RestServlet): user_id, account_data_type, body ) - self.notifier.on_new_event( - "account_data_key", max_id, users=[user_id] - ) + self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) defer.returnValue((200, {})) @@ -65,7 +64,7 @@ class AccountDataServlet(RestServlet): raise AuthError(403, "Cannot get account data for other users.") event = yield self.store.get_global_account_data_by_type_for_user( - account_data_type, user_id, + account_data_type, user_id ) if event is None: @@ -79,6 +78,7 @@ class RoomAccountDataServlet(RestServlet): PUT /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1 GET /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1 """ + PATTERNS = client_patterns( "/user/(?P<user_id>[^/]*)" "/rooms/(?P<room_id>[^/]*)" @@ -103,16 +103,14 @@ class RoomAccountDataServlet(RestServlet): raise SynapseError( 405, "Cannot set m.fully_read through this API." - " Use /rooms/!roomId:server.name/read_markers" + " Use /rooms/!roomId:server.name/read_markers", ) max_id = yield self.store.add_account_data_to_room( user_id, room_id, account_data_type, body ) - self.notifier.on_new_event( - "account_data_key", max_id, users=[user_id] - ) + self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) defer.returnValue((200, {})) @@ -123,7 +121,7 @@ class RoomAccountDataServlet(RestServlet): raise AuthError(403, "Cannot get account data for other users.") event = yield self.store.get_account_data_for_room_and_type( - user_id, room_id, account_data_type, + user_id, room_id, account_data_type ) if event is None: diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py index 63bdc33564..d29c10b83d 100644 --- a/synapse/rest/client/v2_alpha/account_validity.py +++ b/synapse/rest/client/v2_alpha/account_validity.py @@ -28,7 +28,9 @@ logger = logging.getLogger(__name__) class AccountValidityRenewServlet(RestServlet): PATTERNS = client_patterns("/account_validity/renew$") - SUCCESS_HTML = b"<html><body>Your account has been successfully renewed.</body><html>" + SUCCESS_HTML = ( + b"<html><body>Your account has been successfully renewed.</body><html>" + ) def __init__(self, hs): """ @@ -47,13 +49,13 @@ class AccountValidityRenewServlet(RestServlet): raise SynapseError(400, "Missing renewal token") renewal_token = request.args[b"token"][0] - yield self.account_activity_handler.renew_account(renewal_token.decode('utf8')) + yield self.account_activity_handler.renew_account(renewal_token.decode("utf8")) request.setResponseCode(200) request.setHeader(b"Content-Type", b"text/html; charset=utf-8") - request.setHeader(b"Content-Length", b"%d" % ( - len(AccountValidityRenewServlet.SUCCESS_HTML), - )) + request.setHeader( + b"Content-Length", b"%d" % (len(AccountValidityRenewServlet.SUCCESS_HTML),) + ) request.write(AccountValidityRenewServlet.SUCCESS_HTML) finish_request(request) defer.returnValue(None) @@ -77,7 +79,9 @@ class AccountValiditySendMailServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request): if not self.account_validity.renew_by_email_enabled: - raise AuthError(403, "Account renewal via email is disabled on this server.") + raise AuthError( + 403, "Account renewal via email is disabled on this server." + ) requester = yield self.auth.get_user_by_req(request, allow_expired=True) user_id = requester.user.to_string() diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 8dfe5cba02..bebc2951e7 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -122,6 +122,7 @@ class AuthRestServlet(RestServlet): cannot be handled in the normal flow (with requests to the same endpoint). Current use is for web fallback auth. """ + PATTERNS = client_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web") def __init__(self, hs): @@ -138,11 +139,10 @@ class AuthRestServlet(RestServlet): if stagetype == LoginType.RECAPTCHA: html = RECAPTCHA_TEMPLATE % { - 'session': session, - 'myurl': "%s/r0/auth/%s/fallback/web" % ( - CLIENT_API_PREFIX, LoginType.RECAPTCHA - ), - 'sitekey': self.hs.config.recaptcha_public_key, + "session": session, + "myurl": "%s/r0/auth/%s/fallback/web" + % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), + "sitekey": self.hs.config.recaptcha_public_key, } html_bytes = html.encode("utf8") request.setResponseCode(200) @@ -154,14 +154,11 @@ class AuthRestServlet(RestServlet): return None elif stagetype == LoginType.TERMS: html = TERMS_TEMPLATE % { - 'session': session, - 'terms_url': "%s_matrix/consent?v=%s" % ( - self.hs.config.public_baseurl, - self.hs.config.user_consent_version, - ), - 'myurl': "%s/r0/auth/%s/fallback/web" % ( - CLIENT_API_PREFIX, LoginType.TERMS - ), + "session": session, + "terms_url": "%s_matrix/consent?v=%s" + % (self.hs.config.public_baseurl, self.hs.config.user_consent_version), + "myurl": "%s/r0/auth/%s/fallback/web" + % (CLIENT_API_PREFIX, LoginType.TERMS), } html_bytes = html.encode("utf8") request.setResponseCode(200) @@ -187,26 +184,20 @@ class AuthRestServlet(RestServlet): if not response: raise SynapseError(400, "No captcha response supplied") - authdict = { - 'response': response, - 'session': session, - } + authdict = {"response": response, "session": session} success = yield self.auth_handler.add_oob_auth( - LoginType.RECAPTCHA, - authdict, - self.hs.get_ip_from_request(request) + LoginType.RECAPTCHA, authdict, self.hs.get_ip_from_request(request) ) if success: html = SUCCESS_TEMPLATE else: html = RECAPTCHA_TEMPLATE % { - 'session': session, - 'myurl': "%s/r0/auth/%s/fallback/web" % ( - CLIENT_API_PREFIX, LoginType.RECAPTCHA - ), - 'sitekey': self.hs.config.recaptcha_public_key, + "session": session, + "myurl": "%s/r0/auth/%s/fallback/web" + % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), + "sitekey": self.hs.config.recaptcha_public_key, } html_bytes = html.encode("utf8") request.setResponseCode(200) @@ -218,31 +209,28 @@ class AuthRestServlet(RestServlet): defer.returnValue(None) elif stagetype == LoginType.TERMS: - if ('session' not in request.args or - len(request.args['session'])) == 0: + if ("session" not in request.args or len(request.args["session"])) == 0: raise SynapseError(400, "No session supplied") - session = request.args['session'][0] - authdict = {'session': session} + session = request.args["session"][0] + authdict = {"session": session} success = yield self.auth_handler.add_oob_auth( - LoginType.TERMS, - authdict, - self.hs.get_ip_from_request(request) + LoginType.TERMS, authdict, self.hs.get_ip_from_request(request) ) if success: html = SUCCESS_TEMPLATE else: html = TERMS_TEMPLATE % { - 'session': session, - 'terms_url': "%s_matrix/consent?v=%s" % ( + "session": session, + "terms_url": "%s_matrix/consent?v=%s" + % ( self.hs.config.public_baseurl, self.hs.config.user_consent_version, ), - 'myurl': "%s/r0/auth/%s/fallback/web" % ( - CLIENT_API_PREFIX, LoginType.TERMS - ), + "myurl": "%s/r0/auth/%s/fallback/web" + % (CLIENT_API_PREFIX, LoginType.TERMS), } html_bytes = html.encode("utf8") request.setResponseCode(200) diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index 78665304a5..d279229d74 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -56,6 +56,7 @@ class DeleteDevicesRestServlet(RestServlet): API for bulk deletion of devices. Accepts a JSON object with a devices key which lists the device_ids to delete. Requires user interactive auth. """ + PATTERNS = client_patterns("/delete_devices") def __init__(self, hs): @@ -84,12 +85,11 @@ class DeleteDevicesRestServlet(RestServlet): assert_params_in_dict(body, ["devices"]) yield self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request), + requester, body, self.hs.get_ip_from_request(request) ) yield self.device_handler.delete_devices( - requester.user.to_string(), - body['devices'], + requester.user.to_string(), body["devices"] ) defer.returnValue((200, {})) @@ -112,8 +112,7 @@ class DeviceRestServlet(RestServlet): def on_GET(self, request, device_id): requester = yield self.auth.get_user_by_req(request, allow_guest=True) device = yield self.device_handler.get_device( - requester.user.to_string(), - device_id, + requester.user.to_string(), device_id ) defer.returnValue((200, device)) @@ -134,12 +133,10 @@ class DeviceRestServlet(RestServlet): raise yield self.auth_handler.validate_user_via_ui_auth( - requester, body, self.hs.get_ip_from_request(request), + requester, body, self.hs.get_ip_from_request(request) ) - yield self.device_handler.delete_device( - requester.user.to_string(), device_id, - ) + yield self.device_handler.delete_device(requester.user.to_string(), device_id) defer.returnValue((200, {})) @defer.inlineCallbacks @@ -148,9 +145,7 @@ class DeviceRestServlet(RestServlet): body = parse_json_object_from_request(request) yield self.device_handler.update_device( - requester.user.to_string(), - device_id, - body + requester.user.to_string(), device_id, body ) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py index 65db48c3cc..3f0adf4a21 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py @@ -53,8 +53,7 @@ class GetFilterRestServlet(RestServlet): try: filter = yield self.filtering.get_user_filter( - user_localpart=target_user.localpart, - filter_id=filter_id, + user_localpart=target_user.localpart, filter_id=filter_id ) defer.returnValue((200, filter.get_filter_json())) @@ -84,14 +83,10 @@ class CreateFilterRestServlet(RestServlet): raise AuthError(403, "Can only create filters for local users") content = parse_json_object_from_request(request) - set_timeline_upper_limit( - content, - self.hs.config.filter_timeline_limit - ) + set_timeline_upper_limit(content, self.hs.config.filter_timeline_limit) filter_id = yield self.filtering.add_user_filter( - user_localpart=target_user.localpart, - user_filter=content, + user_localpart=target_user.localpart, user_filter=content ) defer.returnValue((200, {"filter_id": str(filter_id)})) diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py index d082385ec7..a312dd2593 100644 --- a/synapse/rest/client/v2_alpha/groups.py +++ b/synapse/rest/client/v2_alpha/groups.py @@ -29,6 +29,7 @@ logger = logging.getLogger(__name__) class GroupServlet(RestServlet): """Get the group profile """ + PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/profile$") def __init__(self, hs): @@ -43,8 +44,7 @@ class GroupServlet(RestServlet): requester_user_id = requester.user.to_string() group_description = yield self.groups_handler.get_group_profile( - group_id, - requester_user_id, + group_id, requester_user_id ) defer.returnValue((200, group_description)) @@ -56,7 +56,7 @@ class GroupServlet(RestServlet): content = parse_json_object_from_request(request) yield self.groups_handler.update_group_profile( - group_id, requester_user_id, content, + group_id, requester_user_id, content ) defer.returnValue((200, {})) @@ -65,6 +65,7 @@ class GroupServlet(RestServlet): class GroupSummaryServlet(RestServlet): """Get the full group summary """ + PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/summary$") def __init__(self, hs): @@ -79,8 +80,7 @@ class GroupSummaryServlet(RestServlet): requester_user_id = requester.user.to_string() get_group_summary = yield self.groups_handler.get_group_summary( - group_id, - requester_user_id, + group_id, requester_user_id ) defer.returnValue((200, get_group_summary)) @@ -93,6 +93,7 @@ class GroupSummaryRoomsCatServlet(RestServlet): - /groups/:group/summary/rooms/:room_id - /groups/:group/summary/categories/:category/rooms/:room_id """ + PATTERNS = client_patterns( "/groups/(?P<group_id>[^/]*)/summary" "(/categories/(?P<category_id>[^/]+))?" @@ -112,7 +113,8 @@ class GroupSummaryRoomsCatServlet(RestServlet): content = parse_json_object_from_request(request) resp = yield self.groups_handler.update_group_summary_room( - group_id, requester_user_id, + group_id, + requester_user_id, room_id=room_id, category_id=category_id, content=content, @@ -126,9 +128,7 @@ class GroupSummaryRoomsCatServlet(RestServlet): requester_user_id = requester.user.to_string() resp = yield self.groups_handler.delete_group_summary_room( - group_id, requester_user_id, - room_id=room_id, - category_id=category_id, + group_id, requester_user_id, room_id=room_id, category_id=category_id ) defer.returnValue((200, resp)) @@ -137,6 +137,7 @@ class GroupSummaryRoomsCatServlet(RestServlet): class GroupCategoryServlet(RestServlet): """Get/add/update/delete a group category """ + PATTERNS = client_patterns( "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$" ) @@ -153,8 +154,7 @@ class GroupCategoryServlet(RestServlet): requester_user_id = requester.user.to_string() category = yield self.groups_handler.get_group_category( - group_id, requester_user_id, - category_id=category_id, + group_id, requester_user_id, category_id=category_id ) defer.returnValue((200, category)) @@ -166,9 +166,7 @@ class GroupCategoryServlet(RestServlet): content = parse_json_object_from_request(request) resp = yield self.groups_handler.update_group_category( - group_id, requester_user_id, - category_id=category_id, - content=content, + group_id, requester_user_id, category_id=category_id, content=content ) defer.returnValue((200, resp)) @@ -179,8 +177,7 @@ class GroupCategoryServlet(RestServlet): requester_user_id = requester.user.to_string() resp = yield self.groups_handler.delete_group_category( - group_id, requester_user_id, - category_id=category_id, + group_id, requester_user_id, category_id=category_id ) defer.returnValue((200, resp)) @@ -189,9 +186,8 @@ class GroupCategoryServlet(RestServlet): class GroupCategoriesServlet(RestServlet): """Get all group categories """ - PATTERNS = client_patterns( - "/groups/(?P<group_id>[^/]*)/categories/$" - ) + + PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/categories/$") def __init__(self, hs): super(GroupCategoriesServlet, self).__init__() @@ -205,7 +201,7 @@ class GroupCategoriesServlet(RestServlet): requester_user_id = requester.user.to_string() category = yield self.groups_handler.get_group_categories( - group_id, requester_user_id, + group_id, requester_user_id ) defer.returnValue((200, category)) @@ -214,9 +210,8 @@ class GroupCategoriesServlet(RestServlet): class GroupRoleServlet(RestServlet): """Get/add/update/delete a group role """ - PATTERNS = client_patterns( - "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$" - ) + + PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$") def __init__(self, hs): super(GroupRoleServlet, self).__init__() @@ -230,8 +225,7 @@ class GroupRoleServlet(RestServlet): requester_user_id = requester.user.to_string() category = yield self.groups_handler.get_group_role( - group_id, requester_user_id, - role_id=role_id, + group_id, requester_user_id, role_id=role_id ) defer.returnValue((200, category)) @@ -243,9 +237,7 @@ class GroupRoleServlet(RestServlet): content = parse_json_object_from_request(request) resp = yield self.groups_handler.update_group_role( - group_id, requester_user_id, - role_id=role_id, - content=content, + group_id, requester_user_id, role_id=role_id, content=content ) defer.returnValue((200, resp)) @@ -256,8 +248,7 @@ class GroupRoleServlet(RestServlet): requester_user_id = requester.user.to_string() resp = yield self.groups_handler.delete_group_role( - group_id, requester_user_id, - role_id=role_id, + group_id, requester_user_id, role_id=role_id ) defer.returnValue((200, resp)) @@ -266,9 +257,8 @@ class GroupRoleServlet(RestServlet): class GroupRolesServlet(RestServlet): """Get all group roles """ - PATTERNS = client_patterns( - "/groups/(?P<group_id>[^/]*)/roles/$" - ) + + PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/$") def __init__(self, hs): super(GroupRolesServlet, self).__init__() @@ -282,7 +272,7 @@ class GroupRolesServlet(RestServlet): requester_user_id = requester.user.to_string() category = yield self.groups_handler.get_group_roles( - group_id, requester_user_id, + group_id, requester_user_id ) defer.returnValue((200, category)) @@ -295,6 +285,7 @@ class GroupSummaryUsersRoleServlet(RestServlet): - /groups/:group/summary/users/:room_id - /groups/:group/summary/roles/:role/users/:user_id """ + PATTERNS = client_patterns( "/groups/(?P<group_id>[^/]*)/summary" "(/roles/(?P<role_id>[^/]+))?" @@ -314,7 +305,8 @@ class GroupSummaryUsersRoleServlet(RestServlet): content = parse_json_object_from_request(request) resp = yield self.groups_handler.update_group_summary_user( - group_id, requester_user_id, + group_id, + requester_user_id, user_id=user_id, role_id=role_id, content=content, @@ -328,9 +320,7 @@ class GroupSummaryUsersRoleServlet(RestServlet): requester_user_id = requester.user.to_string() resp = yield self.groups_handler.delete_group_summary_user( - group_id, requester_user_id, - user_id=user_id, - role_id=role_id, + group_id, requester_user_id, user_id=user_id, role_id=role_id ) defer.returnValue((200, resp)) @@ -339,6 +329,7 @@ class GroupSummaryUsersRoleServlet(RestServlet): class GroupRoomServlet(RestServlet): """Get all rooms in a group """ + PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/rooms$") def __init__(self, hs): @@ -352,7 +343,9 @@ class GroupRoomServlet(RestServlet): requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - result = yield self.groups_handler.get_rooms_in_group(group_id, requester_user_id) + result = yield self.groups_handler.get_rooms_in_group( + group_id, requester_user_id + ) defer.returnValue((200, result)) @@ -360,6 +353,7 @@ class GroupRoomServlet(RestServlet): class GroupUsersServlet(RestServlet): """Get all users in a group """ + PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/users$") def __init__(self, hs): @@ -373,7 +367,9 @@ class GroupUsersServlet(RestServlet): requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() - result = yield self.groups_handler.get_users_in_group(group_id, requester_user_id) + result = yield self.groups_handler.get_users_in_group( + group_id, requester_user_id + ) defer.returnValue((200, result)) @@ -381,6 +377,7 @@ class GroupUsersServlet(RestServlet): class GroupInvitedUsersServlet(RestServlet): """Get users invited to a group """ + PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/invited_users$") def __init__(self, hs): @@ -395,8 +392,7 @@ class GroupInvitedUsersServlet(RestServlet): requester_user_id = requester.user.to_string() result = yield self.groups_handler.get_invited_users_in_group( - group_id, - requester_user_id, + group_id, requester_user_id ) defer.returnValue((200, result)) @@ -405,6 +401,7 @@ class GroupInvitedUsersServlet(RestServlet): class GroupSettingJoinPolicyServlet(RestServlet): """Set group join policy """ + PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$") def __init__(self, hs): @@ -420,9 +417,7 @@ class GroupSettingJoinPolicyServlet(RestServlet): content = parse_json_object_from_request(request) result = yield self.groups_handler.set_group_join_policy( - group_id, - requester_user_id, - content, + group_id, requester_user_id, content ) defer.returnValue((200, result)) @@ -431,6 +426,7 @@ class GroupSettingJoinPolicyServlet(RestServlet): class GroupCreateServlet(RestServlet): """Create a group """ + PATTERNS = client_patterns("/create_group$") def __init__(self, hs): @@ -451,9 +447,7 @@ class GroupCreateServlet(RestServlet): group_id = GroupID(localpart, self.server_name).to_string() result = yield self.groups_handler.create_group( - group_id, - requester_user_id, - content, + group_id, requester_user_id, content ) defer.returnValue((200, result)) @@ -462,6 +456,7 @@ class GroupCreateServlet(RestServlet): class GroupAdminRoomsServlet(RestServlet): """Add a room to the group """ + PATTERNS = client_patterns( "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$" ) @@ -479,7 +474,7 @@ class GroupAdminRoomsServlet(RestServlet): content = parse_json_object_from_request(request) result = yield self.groups_handler.add_room_to_group( - group_id, requester_user_id, room_id, content, + group_id, requester_user_id, room_id, content ) defer.returnValue((200, result)) @@ -490,7 +485,7 @@ class GroupAdminRoomsServlet(RestServlet): requester_user_id = requester.user.to_string() result = yield self.groups_handler.remove_room_from_group( - group_id, requester_user_id, room_id, + group_id, requester_user_id, room_id ) defer.returnValue((200, result)) @@ -499,6 +494,7 @@ class GroupAdminRoomsServlet(RestServlet): class GroupAdminRoomsConfigServlet(RestServlet): """Update the config of a room in a group """ + PATTERNS = client_patterns( "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)" "/config/(?P<config_key>[^/]*)$" @@ -517,7 +513,7 @@ class GroupAdminRoomsConfigServlet(RestServlet): content = parse_json_object_from_request(request) result = yield self.groups_handler.update_room_in_group( - group_id, requester_user_id, room_id, config_key, content, + group_id, requester_user_id, room_id, config_key, content ) defer.returnValue((200, result)) @@ -526,6 +522,7 @@ class GroupAdminRoomsConfigServlet(RestServlet): class GroupAdminUsersInviteServlet(RestServlet): """Invite a user to the group """ + PATTERNS = client_patterns( "/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$" ) @@ -546,7 +543,7 @@ class GroupAdminUsersInviteServlet(RestServlet): content = parse_json_object_from_request(request) config = content.get("config", {}) result = yield self.groups_handler.invite( - group_id, user_id, requester_user_id, config, + group_id, user_id, requester_user_id, config ) defer.returnValue((200, result)) @@ -555,6 +552,7 @@ class GroupAdminUsersInviteServlet(RestServlet): class GroupAdminUsersKickServlet(RestServlet): """Kick a user from the group """ + PATTERNS = client_patterns( "/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$" ) @@ -572,7 +570,7 @@ class GroupAdminUsersKickServlet(RestServlet): content = parse_json_object_from_request(request) result = yield self.groups_handler.remove_user_from_group( - group_id, user_id, requester_user_id, content, + group_id, user_id, requester_user_id, content ) defer.returnValue((200, result)) @@ -581,9 +579,8 @@ class GroupAdminUsersKickServlet(RestServlet): class GroupSelfLeaveServlet(RestServlet): """Leave a joined group """ - PATTERNS = client_patterns( - "/groups/(?P<group_id>[^/]*)/self/leave$" - ) + + PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/leave$") def __init__(self, hs): super(GroupSelfLeaveServlet, self).__init__() @@ -598,7 +595,7 @@ class GroupSelfLeaveServlet(RestServlet): content = parse_json_object_from_request(request) result = yield self.groups_handler.remove_user_from_group( - group_id, requester_user_id, requester_user_id, content, + group_id, requester_user_id, requester_user_id, content ) defer.returnValue((200, result)) @@ -607,9 +604,8 @@ class GroupSelfLeaveServlet(RestServlet): class GroupSelfJoinServlet(RestServlet): """Attempt to join a group, or knock """ - PATTERNS = client_patterns( - "/groups/(?P<group_id>[^/]*)/self/join$" - ) + + PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/join$") def __init__(self, hs): super(GroupSelfJoinServlet, self).__init__() @@ -624,7 +620,7 @@ class GroupSelfJoinServlet(RestServlet): content = parse_json_object_from_request(request) result = yield self.groups_handler.join_group( - group_id, requester_user_id, content, + group_id, requester_user_id, content ) defer.returnValue((200, result)) @@ -633,9 +629,8 @@ class GroupSelfJoinServlet(RestServlet): class GroupSelfAcceptInviteServlet(RestServlet): """Accept a group invite """ - PATTERNS = client_patterns( - "/groups/(?P<group_id>[^/]*)/self/accept_invite$" - ) + + PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/accept_invite$") def __init__(self, hs): super(GroupSelfAcceptInviteServlet, self).__init__() @@ -650,7 +645,7 @@ class GroupSelfAcceptInviteServlet(RestServlet): content = parse_json_object_from_request(request) result = yield self.groups_handler.accept_invite( - group_id, requester_user_id, content, + group_id, requester_user_id, content ) defer.returnValue((200, result)) @@ -659,9 +654,8 @@ class GroupSelfAcceptInviteServlet(RestServlet): class GroupSelfUpdatePublicityServlet(RestServlet): """Update whether we publicise a users membership of a group """ - PATTERNS = client_patterns( - "/groups/(?P<group_id>[^/]*)/self/update_publicity$" - ) + + PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/update_publicity$") def __init__(self, hs): super(GroupSelfUpdatePublicityServlet, self).__init__() @@ -676,9 +670,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet): content = parse_json_object_from_request(request) publicise = content["publicise"] - yield self.store.update_group_publicity( - group_id, requester_user_id, publicise, - ) + yield self.store.update_group_publicity(group_id, requester_user_id, publicise) defer.returnValue((200, {})) @@ -686,9 +678,8 @@ class GroupSelfUpdatePublicityServlet(RestServlet): class PublicisedGroupsForUserServlet(RestServlet): """Get the list of groups a user is advertising """ - PATTERNS = client_patterns( - "/publicised_groups/(?P<user_id>[^/]*)$" - ) + + PATTERNS = client_patterns("/publicised_groups/(?P<user_id>[^/]*)$") def __init__(self, hs): super(PublicisedGroupsForUserServlet, self).__init__() @@ -701,9 +692,7 @@ class PublicisedGroupsForUserServlet(RestServlet): def on_GET(self, request, user_id): yield self.auth.get_user_by_req(request, allow_guest=True) - result = yield self.groups_handler.get_publicised_groups_for_user( - user_id - ) + result = yield self.groups_handler.get_publicised_groups_for_user(user_id) defer.returnValue((200, result)) @@ -711,9 +700,8 @@ class PublicisedGroupsForUserServlet(RestServlet): class PublicisedGroupsForUsersServlet(RestServlet): """Get the list of groups a user is advertising """ - PATTERNS = client_patterns( - "/publicised_groups$" - ) + + PATTERNS = client_patterns("/publicised_groups$") def __init__(self, hs): super(PublicisedGroupsForUsersServlet, self).__init__() @@ -729,9 +717,7 @@ class PublicisedGroupsForUsersServlet(RestServlet): content = parse_json_object_from_request(request) user_ids = content["user_ids"] - result = yield self.groups_handler.bulk_get_publicised_groups( - user_ids - ) + result = yield self.groups_handler.bulk_get_publicised_groups(user_ids) defer.returnValue((200, result)) @@ -739,9 +725,8 @@ class PublicisedGroupsForUsersServlet(RestServlet): class GroupsForUserServlet(RestServlet): """Get all groups the logged in user is joined to """ - PATTERNS = client_patterns( - "/joined_groups$" - ) + + PATTERNS = client_patterns("/joined_groups$") def __init__(self, hs): super(GroupsForUserServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 4cbfbf5631..45c9928b65 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -56,6 +56,7 @@ class KeyUploadServlet(RestServlet): }, } """ + PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$") def __init__(self, hs): @@ -76,18 +77,19 @@ class KeyUploadServlet(RestServlet): if device_id is not None: # passing the device_id here is deprecated; however, we allow it # for now for compatibility with older clients. - if (requester.device_id is not None and - device_id != requester.device_id): - logger.warning("Client uploading keys for a different device " - "(logged in as %s, uploading for %s)", - requester.device_id, device_id) + if requester.device_id is not None and device_id != requester.device_id: + logger.warning( + "Client uploading keys for a different device " + "(logged in as %s, uploading for %s)", + requester.device_id, + device_id, + ) else: device_id = requester.device_id if device_id is None: raise SynapseError( - 400, - "To upload keys, you must pass device_id when authenticating" + 400, "To upload keys, you must pass device_id when authenticating" ) result = yield self.e2e_keys_handler.upload_keys_for_user( @@ -159,6 +161,7 @@ class KeyChangesServlet(RestServlet): 200 OK { "changed": ["@foo:example.com"] } """ + PATTERNS = client_patterns("/keys/changes$") def __init__(self, hs): @@ -184,9 +187,7 @@ class KeyChangesServlet(RestServlet): user_id = requester.user.to_string() - results = yield self.device_handler.get_user_ids_changed( - user_id, from_token, - ) + results = yield self.device_handler.get_user_ids_changed(user_id, from_token) defer.returnValue((200, results)) @@ -209,6 +210,7 @@ class OneTimeKeyServlet(RestServlet): } } } } """ + PATTERNS = client_patterns("/keys/claim$") def __init__(self, hs): @@ -221,10 +223,7 @@ class OneTimeKeyServlet(RestServlet): yield self.auth.get_user_by_req(request, allow_guest=True) timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) - result = yield self.e2e_keys_handler.claim_one_time_keys( - body, - timeout, - ) + result = yield self.e2e_keys_handler.claim_one_time_keys(body, timeout) defer.returnValue((200, result)) diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py index 53e666989b..728a52328f 100644 --- a/synapse/rest/client/v2_alpha/notifications.py +++ b/synapse/rest/client/v2_alpha/notifications.py @@ -51,7 +51,7 @@ class NotificationsServlet(RestServlet): ) receipts_by_room = yield self.store.get_receipts_for_user_with_orderings( - user_id, 'm.read' + user_id, "m.read" ) notif_event_ids = [pa["event_id"] for pa in push_actions] @@ -67,11 +67,13 @@ class NotificationsServlet(RestServlet): "profile_tag": pa["profile_tag"], "actions": pa["actions"], "ts": pa["received_ts"], - "event": (yield self._event_serializer.serialize_event( - notif_events[pa["event_id"]], - self.clock.time_msec(), - event_format=format_event_for_client_v2_without_room_id, - )), + "event": ( + yield self._event_serializer.serialize_event( + notif_events[pa["event_id"]], + self.clock.time_msec(), + event_format=format_event_for_client_v2_without_room_id, + ) + ), } if pa["room_id"] not in receipts_by_room: @@ -80,17 +82,15 @@ class NotificationsServlet(RestServlet): receipt = receipts_by_room[pa["room_id"]] returned_pa["read"] = ( - receipt["topological_ordering"], receipt["stream_ordering"] - ) >= ( - pa["topological_ordering"], pa["stream_ordering"] - ) + receipt["topological_ordering"], + receipt["stream_ordering"], + ) >= (pa["topological_ordering"], pa["stream_ordering"]) returned_push_actions.append(returned_pa) next_token = str(pa["stream_ordering"]) - defer.returnValue((200, { - "notifications": returned_push_actions, - "next_token": next_token, - })) + defer.returnValue( + (200, {"notifications": returned_push_actions, "next_token": next_token}) + ) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/openid.py b/synapse/rest/client/v2_alpha/openid.py index bb927d9f9d..b1b5385b09 100644 --- a/synapse/rest/client/v2_alpha/openid.py +++ b/synapse/rest/client/v2_alpha/openid.py @@ -56,9 +56,8 @@ class IdTokenServlet(RestServlet): "expires_in": 3600, } """ - PATTERNS = client_patterns( - "/user/(?P<user_id>[^/]*)/openid/request_token" - ) + + PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/openid/request_token") EXPIRES_MS = 3600 * 1000 @@ -84,12 +83,17 @@ class IdTokenServlet(RestServlet): yield self.store.insert_open_id_token(token, ts_valid_until_ms, user_id) - defer.returnValue((200, { - "access_token": token, - "token_type": "Bearer", - "matrix_server_name": self.server_name, - "expires_in": self.EXPIRES_MS / 1000, - })) + defer.returnValue( + ( + 200, + { + "access_token": token, + "token_type": "Bearer", + "matrix_server_name": self.server_name, + "expires_in": self.EXPIRES_MS / 1000, + }, + ) + ) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py index f4bd0d077f..e75664279b 100644 --- a/synapse/rest/client/v2_alpha/read_marker.py +++ b/synapse/rest/client/v2_alpha/read_marker.py @@ -48,7 +48,7 @@ class ReadMarkerRestServlet(RestServlet): room_id, "m.read", user_id=requester.user.to_string(), - event_id=read_event_id + event_id=read_event_id, ) read_marker_event_id = body.get("m.fully_read", None) @@ -56,7 +56,7 @@ class ReadMarkerRestServlet(RestServlet): yield self.read_marker_handler.received_client_read_marker( room_id, user_id=requester.user.to_string(), - event_id=read_marker_event_id + event_id=read_marker_event_id, ) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index fa12ac3e4d..488905626a 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -49,10 +49,7 @@ class ReceiptRestServlet(RestServlet): yield self.presence_handler.bump_presence_active_time(requester.user) yield self.receipts_handler.received_client_receipt( - room_id, - receipt_type, - user_id=requester.user.to_string(), - event_id=event_id + room_id, receipt_type, user_id=requester.user.to_string(), event_id=event_id ) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 79c085408b..5c120e4dd5 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -52,6 +52,7 @@ from ._base import client_patterns, interactive_auth_handler if hasattr(hmac, "compare_digest"): compare_digest = hmac.compare_digest else: + def compare_digest(a, b): return a == b @@ -75,11 +76,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): def on_POST(self, request): body = parse_json_object_from_request(request) - assert_params_in_dict(body, [ - 'id_server', 'client_secret', 'email', 'send_attempt' - ]) + assert_params_in_dict( + body, ["id_server", "client_secret", "email", "send_attempt"] + ) - if not check_3pid_allowed(self.hs, "email", body['email']): + if not check_3pid_allowed(self.hs, "email", body["email"]): raise SynapseError( 403, "Your email domain is not authorized to register on this server", @@ -87,7 +88,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): ) existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( - 'email', body['email'] + "email", body["email"] ) if existingUid is not None: @@ -113,13 +114,12 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): def on_POST(self, request): body = parse_json_object_from_request(request) - assert_params_in_dict(body, [ - 'id_server', 'client_secret', - 'country', 'phone_number', - 'send_attempt', - ]) + assert_params_in_dict( + body, + ["id_server", "client_secret", "country", "phone_number", "send_attempt"], + ) - msisdn = phone_number_to_msisdn(body['country'], body['phone_number']) + msisdn = phone_number_to_msisdn(body["country"], body["phone_number"]) if not check_3pid_allowed(self.hs, "msisdn", msisdn): raise SynapseError( @@ -129,7 +129,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): ) existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( - 'msisdn', msisdn + "msisdn", msisdn ) if existingUid is not None: @@ -165,7 +165,7 @@ class UsernameAvailabilityRestServlet(RestServlet): reject_limit=1, # Allow 1 request at a time concurrent_requests=1, - ) + ), ) @defer.inlineCallbacks @@ -212,7 +212,8 @@ class RegisterRestServlet(RestServlet): time_now = self.clock.time() allowed, time_allowed = self.ratelimiter.can_do_action( - client_addr, time_now_s=time_now, + client_addr, + time_now_s=time_now, rate_hz=self.hs.config.rc_registration.per_second, burst_count=self.hs.config.rc_registration.burst_count, update=False, @@ -220,7 +221,7 @@ class RegisterRestServlet(RestServlet): if not allowed: raise LimitExceededError( - retry_after_ms=int(1000 * (time_allowed - time_now)), + retry_after_ms=int(1000 * (time_allowed - time_now)) ) kind = b"user" @@ -239,18 +240,22 @@ class RegisterRestServlet(RestServlet): # we do basic sanity checks here because the auth layer will store these # in sessions. Pull out the username/password provided to us. desired_password = None - if 'password' in body: - if (not isinstance(body['password'], string_types) or - len(body['password']) > 512): + if "password" in body: + if ( + not isinstance(body["password"], string_types) + or len(body["password"]) > 512 + ): raise SynapseError(400, "Invalid password") desired_password = body["password"] desired_username = None - if 'username' in body: - if (not isinstance(body['username'], string_types) or - len(body['username']) > 512): + if "username" in body: + if ( + not isinstance(body["username"], string_types) + or len(body["username"]) > 512 + ): raise SynapseError(400, "Invalid username") - desired_username = body['username'] + desired_username = body["username"] appservice = None if self.auth.has_access_token(request): @@ -290,7 +295,7 @@ class RegisterRestServlet(RestServlet): desired_username = desired_username.lower() # == Shared Secret Registration == (e.g. create new user scripts) - if 'mac' in body: + if "mac" in body: # FIXME: Should we really be determining if this is shared secret # auth based purely on the 'mac' key? result = yield self._do_shared_secret_registration( @@ -305,16 +310,13 @@ class RegisterRestServlet(RestServlet): guest_access_token = body.get("guest_access_token", None) - if ( - 'initial_device_display_name' in body and - 'password' not in body - ): + if "initial_device_display_name" in body and "password" not in body: # ignore 'initial_device_display_name' if sent without # a password to work around a client bug where it sent # the 'initial_device_display_name' param alone, wiping out # the original registration params logger.warn("Ignoring initial_device_display_name without password") - del body['initial_device_display_name'] + del body["initial_device_display_name"] session_id = self.auth_handler.get_session_id(body) registered_user_id = None @@ -336,8 +338,8 @@ class RegisterRestServlet(RestServlet): # FIXME: need a better error than "no auth flow found" for scenarios # where we required 3PID for registration but the user didn't give one - require_email = 'email' in self.hs.config.registrations_require_3pid - require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid + require_email = "email" in self.hs.config.registrations_require_3pid + require_msisdn = "msisdn" in self.hs.config.registrations_require_3pid show_msisdn = True if self.hs.config.disable_msisdn_registration: @@ -362,9 +364,9 @@ class RegisterRestServlet(RestServlet): if not require_email: flows.extend([[LoginType.RECAPTCHA, LoginType.MSISDN]]) # always let users provide both MSISDN & email - flows.extend([ - [LoginType.RECAPTCHA, LoginType.MSISDN, LoginType.EMAIL_IDENTITY], - ]) + flows.extend( + [[LoginType.RECAPTCHA, LoginType.MSISDN, LoginType.EMAIL_IDENTITY]] + ) else: # only support 3PIDless registration if no 3PIDs are required if not require_email and not require_msisdn: @@ -378,9 +380,7 @@ class RegisterRestServlet(RestServlet): if not require_email or require_msisdn: flows.extend([[LoginType.MSISDN]]) # always let users provide both MSISDN & email - flows.extend([ - [LoginType.MSISDN, LoginType.EMAIL_IDENTITY] - ]) + flows.extend([[LoginType.MSISDN, LoginType.EMAIL_IDENTITY]]) # Append m.login.terms to all flows if we're requiring consent if self.hs.config.user_consent_at_registration: @@ -410,21 +410,20 @@ class RegisterRestServlet(RestServlet): if auth_result: for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]: if login_type in auth_result: - medium = auth_result[login_type]['medium'] - address = auth_result[login_type]['address'] + medium = auth_result[login_type]["medium"] + address = auth_result[login_type]["address"] if not check_3pid_allowed(self.hs, medium, address): raise SynapseError( 403, - "Third party identifiers (email/phone numbers)" + - " are not authorized on this server", + "Third party identifiers (email/phone numbers)" + + " are not authorized on this server", Codes.THREEPID_DENIED, ) if registered_user_id is not None: logger.info( - "Already registered user ID %r for this session", - registered_user_id + "Already registered user ID %r for this session", registered_user_id ) # don't re-register the threepids registered = False @@ -451,11 +450,11 @@ class RegisterRestServlet(RestServlet): # the two activation emails, they would register the same 3pid twice. for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]: if login_type in auth_result: - medium = auth_result[login_type]['medium'] - address = auth_result[login_type]['address'] + medium = auth_result[login_type]["medium"] + address = auth_result[login_type]["address"] existingUid = yield self.store.get_user_id_by_threepid( - medium, address, + medium, address ) if existingUid is not None: @@ -520,7 +519,7 @@ class RegisterRestServlet(RestServlet): raise SynapseError(400, "Shared secret registration is not enabled") if not username: raise SynapseError( - 400, "username must be specified", errcode=Codes.BAD_JSON, + 400, "username must be specified", errcode=Codes.BAD_JSON ) # use the username from the original request rather than the @@ -541,12 +540,10 @@ class RegisterRestServlet(RestServlet): ).hexdigest() if not compare_digest(want_mac, got_mac): - raise SynapseError( - 403, "HMAC incorrect", - ) + raise SynapseError(403, "HMAC incorrect") (user_id, _) = yield self.registration_handler.register( - localpart=username, password=password, generate_token=False, + localpart=username, password=password, generate_token=False ) result = yield self._create_registration_details(user_id, body) @@ -565,21 +562,15 @@ class RegisterRestServlet(RestServlet): Returns: defer.Deferred: (object) dictionary for response from /register """ - result = { - "user_id": user_id, - "home_server": self.hs.hostname, - } + result = {"user_id": user_id, "home_server": self.hs.hostname} if not params.get("inhibit_login", False): device_id = params.get("device_id") initial_display_name = params.get("initial_device_display_name") device_id, access_token = yield self.registration_handler.register_device( - user_id, device_id, initial_display_name, is_guest=False, + user_id, device_id, initial_display_name, is_guest=False ) - result.update({ - "access_token": access_token, - "device_id": device_id, - }) + result.update({"access_token": access_token, "device_id": device_id}) defer.returnValue(result) @defer.inlineCallbacks @@ -587,9 +578,7 @@ class RegisterRestServlet(RestServlet): if not self.hs.config.allow_guest_access: raise SynapseError(403, "Guest access is disabled") user_id, _ = yield self.registration_handler.register( - generate_token=False, - make_guest=True, - address=address, + generate_token=False, make_guest=True, address=address ) # we don't allow guests to specify their own device_id, because @@ -597,15 +586,20 @@ class RegisterRestServlet(RestServlet): device_id = synapse.api.auth.GUEST_DEVICE_ID initial_display_name = params.get("initial_device_display_name") device_id, access_token = yield self.registration_handler.register_device( - user_id, device_id, initial_display_name, is_guest=True, + user_id, device_id, initial_display_name, is_guest=True ) - defer.returnValue((200, { - "user_id": user_id, - "device_id": device_id, - "access_token": access_token, - "home_server": self.hs.hostname, - })) + defer.returnValue( + ( + 200, + { + "user_id": user_id, + "device_id": device_id, + "access_token": access_token, + "home_server": self.hs.hostname, + }, + ) + ) def register_servlets(hs, http_server): diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py index f8f8742bdc..8e362782cc 100644 --- a/synapse/rest/client/v2_alpha/relations.py +++ b/synapse/rest/client/v2_alpha/relations.py @@ -32,7 +32,10 @@ from synapse.http.servlet import ( parse_string, ) from synapse.rest.client.transactions import HttpTransactionCache -from synapse.storage.relations import AggregationPaginationToken, RelationPaginationToken +from synapse.storage.relations import ( + AggregationPaginationToken, + RelationPaginationToken, +) from ._base import client_patterns diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py index 10198662a9..e7578af804 100644 --- a/synapse/rest/client/v2_alpha/report_event.py +++ b/synapse/rest/client/v2_alpha/report_event.py @@ -33,9 +33,7 @@ logger = logging.getLogger(__name__) class ReportEventRestServlet(RestServlet): - PATTERNS = client_patterns( - "/rooms/(?P<room_id>[^/]*)/report/(?P<event_id>[^/]*)$" - ) + PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/report/(?P<event_id>[^/]*)$") def __init__(self, hs): super(ReportEventRestServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py index 87779645f9..8d1b810565 100644 --- a/synapse/rest/client/v2_alpha/room_keys.py +++ b/synapse/rest/client/v2_alpha/room_keys.py @@ -129,22 +129,12 @@ class RoomKeysServlet(RestServlet): version = parse_string(request, "version") if session_id: - body = { - "sessions": { - session_id: body - } - } + body = {"sessions": {session_id: body}} if room_id: - body = { - "rooms": { - room_id: body - } - } + body = {"rooms": {room_id: body}} - yield self.e2e_room_keys_handler.upload_room_keys( - user_id, version, body - ) + yield self.e2e_room_keys_handler.upload_room_keys(user_id, version, body) defer.returnValue((200, {})) @defer.inlineCallbacks @@ -212,10 +202,10 @@ class RoomKeysServlet(RestServlet): if session_id: # If the client requests a specific session, but that session was # not backed up, then return an M_NOT_FOUND. - if room_keys['rooms'] == {}: + if room_keys["rooms"] == {}: raise NotFoundError("No room_keys found") else: - room_keys = room_keys['rooms'][room_id]['sessions'][session_id] + room_keys = room_keys["rooms"][room_id]["sessions"][session_id] elif room_id: # If the client requests all sessions from a room, but no sessions # are found, then return an empty result rather than an error, so @@ -223,10 +213,10 @@ class RoomKeysServlet(RestServlet): # empty result is valid. (Similarly if the client requests all # sessions from the backup, but in that case, room_keys is already # in the right format, so we don't need to do anything about it.) - if room_keys['rooms'] == {}: - room_keys = {'sessions': {}} + if room_keys["rooms"] == {}: + room_keys = {"sessions": {}} else: - room_keys = room_keys['rooms'][room_id] + room_keys = room_keys["rooms"][room_id] defer.returnValue((200, room_keys)) @@ -256,9 +246,7 @@ class RoomKeysServlet(RestServlet): class RoomKeysNewVersionServlet(RestServlet): - PATTERNS = client_patterns( - "/room_keys/version$" - ) + PATTERNS = client_patterns("/room_keys/version$") def __init__(self, hs): """ @@ -304,9 +292,7 @@ class RoomKeysNewVersionServlet(RestServlet): user_id = requester.user.to_string() info = parse_json_object_from_request(request) - new_version = yield self.e2e_room_keys_handler.create_version( - user_id, info - ) + new_version = yield self.e2e_room_keys_handler.create_version(user_id, info) defer.returnValue((200, {"version": new_version})) # we deliberately don't have a PUT /version, as these things really should @@ -314,9 +300,7 @@ class RoomKeysNewVersionServlet(RestServlet): class RoomKeysVersionServlet(RestServlet): - PATTERNS = client_patterns( - "/room_keys/version(/(?P<version>[^/]+))?$" - ) + PATTERNS = client_patterns("/room_keys/version(/(?P<version>[^/]+))?$") def __init__(self, hs): """ @@ -350,9 +334,7 @@ class RoomKeysVersionServlet(RestServlet): user_id = requester.user.to_string() try: - info = yield self.e2e_room_keys_handler.get_version_info( - user_id, version - ) + info = yield self.e2e_room_keys_handler.get_version_info(user_id, version) except SynapseError as e: if e.code == 404: raise SynapseError(404, "No backup found", Codes.NOT_FOUND) @@ -375,9 +357,7 @@ class RoomKeysVersionServlet(RestServlet): requester = yield self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() - yield self.e2e_room_keys_handler.delete_version( - user_id, version - ) + yield self.e2e_room_keys_handler.delete_version(user_id, version) defer.returnValue((200, {})) @defer.inlineCallbacks @@ -407,11 +387,11 @@ class RoomKeysVersionServlet(RestServlet): info = parse_json_object_from_request(request) if version is None: - raise SynapseError(400, "No version specified to update", Codes.MISSING_PARAM) + raise SynapseError( + 400, "No version specified to update", Codes.MISSING_PARAM + ) - yield self.e2e_room_keys_handler.update_version( - user_id, version, info - ) + yield self.e2e_room_keys_handler.update_version(user_id, version, info) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py index c621a90fba..d7f7faa029 100644 --- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py +++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py @@ -47,9 +47,10 @@ class RoomUpgradeRestServlet(RestServlet): Args: hs (synapse.server.HomeServer): """ + PATTERNS = client_patterns( # /rooms/$roomid/upgrade - "/rooms/(?P<room_id>[^/]*)/upgrade$", + "/rooms/(?P<room_id>[^/]*)/upgrade$" ) def __init__(self, hs): @@ -63,7 +64,7 @@ class RoomUpgradeRestServlet(RestServlet): requester = yield self._auth.get_user_by_req(request) content = parse_json_object_from_request(request) - assert_params_in_dict(content, ("new_version", )) + assert_params_in_dict(content, ("new_version",)) new_version = content["new_version"] if new_version not in KNOWN_ROOM_VERSIONS: @@ -77,9 +78,7 @@ class RoomUpgradeRestServlet(RestServlet): requester, room_id, new_version ) - ret = { - "replacement_room": new_room_id, - } + ret = {"replacement_room": new_room_id} defer.returnValue((200, ret)) diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py index 120a713361..78075b8fc0 100644 --- a/synapse/rest/client/v2_alpha/sendtodevice.py +++ b/synapse/rest/client/v2_alpha/sendtodevice.py @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) class SendToDeviceRestServlet(servlet.RestServlet): PATTERNS = client_patterns( - "/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$", + "/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$" ) def __init__(self, hs): diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 148fc6c985..02d56dee6c 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -96,44 +96,42 @@ class SyncRestServlet(RestServlet): 400, "'from' is not a valid query parameter. Did you mean 'since'?" ) - requester = yield self.auth.get_user_by_req( - request, allow_guest=True - ) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) user = requester.user device_id = requester.device_id timeout = parse_integer(request, "timeout", default=0) since = parse_string(request, "since") set_presence = parse_string( - request, "set_presence", default="online", - allowed_values=self.ALLOWED_PRESENCE + request, + "set_presence", + default="online", + allowed_values=self.ALLOWED_PRESENCE, ) filter_id = parse_string(request, "filter", default=None) full_state = parse_boolean(request, "full_state", default=False) logger.debug( "/sync: user=%r, timeout=%r, since=%r," - " set_presence=%r, filter_id=%r, device_id=%r" % ( - user, timeout, since, set_presence, filter_id, device_id - ) + " set_presence=%r, filter_id=%r, device_id=%r" + % (user, timeout, since, set_presence, filter_id, device_id) ) request_key = (user, timeout, since, filter_id, full_state, device_id) if filter_id: - if filter_id.startswith('{'): + if filter_id.startswith("{"): try: filter_object = json.loads(filter_id) - set_timeline_upper_limit(filter_object, - self.hs.config.filter_timeline_limit) + set_timeline_upper_limit( + filter_object, self.hs.config.filter_timeline_limit + ) except Exception: raise SynapseError(400, "Invalid filter JSON") self.filtering.check_valid_filter(filter_object) filter = FilterCollection(filter_object) else: - filter = yield self.filtering.get_user_filter( - user.localpart, filter_id - ) + filter = yield self.filtering.get_user_filter(user.localpart, filter_id) else: filter = DEFAULT_FILTER_COLLECTION @@ -156,15 +154,19 @@ class SyncRestServlet(RestServlet): affect_presence = set_presence != PresenceState.OFFLINE if affect_presence: - yield self.presence_handler.set_state(user, {"presence": set_presence}, True) + yield self.presence_handler.set_state( + user, {"presence": set_presence}, True + ) context = yield self.presence_handler.user_syncing( - user.to_string(), affect_presence=affect_presence, + user.to_string(), affect_presence=affect_presence ) with context: sync_result = yield self.sync_handler.wait_for_sync_for_user( - sync_config, since_token=since_token, timeout=timeout, - full_state=full_state + sync_config, + since_token=since_token, + timeout=timeout, + full_state=full_state, ) time_now = self.clock.time_msec() @@ -176,53 +178,54 @@ class SyncRestServlet(RestServlet): @defer.inlineCallbacks def encode_response(self, time_now, sync_result, access_token_id, filter): - if filter.event_format == 'client': + if filter.event_format == "client": event_formatter = format_event_for_client_v2_without_room_id - elif filter.event_format == 'federation': + elif filter.event_format == "federation": event_formatter = format_event_raw else: - raise Exception("Unknown event format %s" % (filter.event_format, )) + raise Exception("Unknown event format %s" % (filter.event_format,)) joined = yield self.encode_joined( - sync_result.joined, time_now, access_token_id, + sync_result.joined, + time_now, + access_token_id, filter.event_fields, event_formatter, ) invited = yield self.encode_invited( - sync_result.invited, time_now, access_token_id, - event_formatter, + sync_result.invited, time_now, access_token_id, event_formatter ) archived = yield self.encode_archived( - sync_result.archived, time_now, access_token_id, + sync_result.archived, + time_now, + access_token_id, filter.event_fields, event_formatter, ) - defer.returnValue({ - "account_data": {"events": sync_result.account_data}, - "to_device": {"events": sync_result.to_device}, - "device_lists": { - "changed": list(sync_result.device_lists.changed), - "left": list(sync_result.device_lists.left), - }, - "presence": SyncRestServlet.encode_presence( - sync_result.presence, time_now - ), - "rooms": { - "join": joined, - "invite": invited, - "leave": archived, - }, - "groups": { - "join": sync_result.groups.join, - "invite": sync_result.groups.invite, - "leave": sync_result.groups.leave, - }, - "device_one_time_keys_count": sync_result.device_one_time_keys_count, - "next_batch": sync_result.next_batch.to_string(), - }) + defer.returnValue( + { + "account_data": {"events": sync_result.account_data}, + "to_device": {"events": sync_result.to_device}, + "device_lists": { + "changed": list(sync_result.device_lists.changed), + "left": list(sync_result.device_lists.left), + }, + "presence": SyncRestServlet.encode_presence( + sync_result.presence, time_now + ), + "rooms": {"join": joined, "invite": invited, "leave": archived}, + "groups": { + "join": sync_result.groups.join, + "invite": sync_result.groups.invite, + "leave": sync_result.groups.leave, + }, + "device_one_time_keys_count": sync_result.device_one_time_keys_count, + "next_batch": sync_result.next_batch.to_string(), + } + ) @staticmethod def encode_presence(events, time_now): @@ -262,7 +265,11 @@ class SyncRestServlet(RestServlet): joined = {} for room in rooms: joined[room.room_id] = yield self.encode_room( - room, time_now, token_id, joined=True, only_fields=event_fields, + room, + time_now, + token_id, + joined=True, + only_fields=event_fields, event_formatter=event_formatter, ) @@ -290,7 +297,9 @@ class SyncRestServlet(RestServlet): invited = {} for room in rooms: invite = yield self._event_serializer.serialize_event( - room.invite, time_now, token_id=token_id, + room.invite, + time_now, + token_id=token_id, event_format=event_formatter, is_invite=True, ) @@ -298,9 +307,7 @@ class SyncRestServlet(RestServlet): invite["unsigned"] = unsigned invited_state = list(unsigned.pop("invite_room_state", [])) invited_state.append(invite) - invited[room.room_id] = { - "invite_state": {"events": invited_state} - } + invited[room.room_id] = {"invite_state": {"events": invited_state}} defer.returnValue(invited) @@ -327,7 +334,10 @@ class SyncRestServlet(RestServlet): joined = {} for room in rooms: joined[room.room_id] = yield self.encode_room( - room, time_now, token_id, joined=False, + room, + time_now, + token_id, + joined=False, only_fields=event_fields, event_formatter=event_formatter, ) @@ -336,8 +346,7 @@ class SyncRestServlet(RestServlet): @defer.inlineCallbacks def encode_room( - self, room, time_now, token_id, joined, - only_fields, event_formatter, + self, room, time_now, token_id, joined, only_fields, event_formatter ): """ Args: @@ -355,9 +364,11 @@ class SyncRestServlet(RestServlet): Returns: dict[str, object]: the room, encoded in our response format """ + def serialize(events): return self._event_serializer.serialize_events( - events, time_now=time_now, + events, + time_now=time_now, # We don't bundle "live" events, as otherwise clients # will end up double counting annotations. bundle_aggregations=False, @@ -377,7 +388,9 @@ class SyncRestServlet(RestServlet): if event.room_id != room.room_id: logger.warn( "Event %r is under room %r instead of %r", - event.event_id, room.room_id, event.room_id, + event.event_id, + room.room_id, + event.room_id, ) serialized_state = yield serialize(state_events) diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py index ebff7cff45..07b6ede603 100644 --- a/synapse/rest/client/v2_alpha/tags.py +++ b/synapse/rest/client/v2_alpha/tags.py @@ -29,9 +29,8 @@ class TagListServlet(RestServlet): """ GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1 """ - PATTERNS = client_patterns( - "/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags" - ) + + PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags") def __init__(self, hs): super(TagListServlet, self).__init__() @@ -54,6 +53,7 @@ class TagServlet(RestServlet): PUT /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1 DELETE /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1 """ + PATTERNS = client_patterns( "/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)" ) @@ -74,9 +74,7 @@ class TagServlet(RestServlet): max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body) - self.notifier.on_new_event( - "account_data_key", max_id, users=[user_id] - ) + self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) defer.returnValue((200, {})) @@ -88,9 +86,7 @@ class TagServlet(RestServlet): max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag) - self.notifier.on_new_event( - "account_data_key", max_id, users=[user_id] - ) + self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index e7a987466a..1e66662a05 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -57,7 +57,7 @@ class ThirdPartyProtocolServlet(RestServlet): yield self.auth.get_user_by_req(request, allow_guest=True) protocols = yield self.appservice_handler.get_3pe_protocols( - only_protocol=protocol, + only_protocol=protocol ) if protocol in protocols: defer.returnValue((200, protocols[protocol])) diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py index 6c366142e1..2da0f55811 100644 --- a/synapse/rest/client/v2_alpha/tokenrefresh.py +++ b/synapse/rest/client/v2_alpha/tokenrefresh.py @@ -26,6 +26,7 @@ class TokenRefreshRestServlet(RestServlet): Exchanges refresh tokens for a pair of an access token and a new refresh token. """ + PATTERNS = client_patterns("/tokenrefresh") def __init__(self, hs): diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py index 69e4efc47a..e19fb6d583 100644 --- a/synapse/rest/client/v2_alpha/user_directory.py +++ b/synapse/rest/client/v2_alpha/user_directory.py @@ -60,10 +60,7 @@ class UserDirectorySearchRestServlet(RestServlet): user_id = requester.user.to_string() if not self.hs.config.user_directory_search_enabled: - defer.returnValue((200, { - "limited": False, - "results": [], - })) + defer.returnValue((200, {"limited": False, "results": []})) body = parse_json_object_from_request(request) @@ -76,7 +73,7 @@ class UserDirectorySearchRestServlet(RestServlet): raise SynapseError(400, "`search_term` is required field") results = yield self.user_directory_handler.search_users( - user_id, search_term, limit, + user_id, search_term, limit ) defer.returnValue((200, results)) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index babbf6a23c..0e09191632 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -25,27 +25,28 @@ class VersionsRestServlet(RestServlet): PATTERNS = [re.compile("^/_matrix/client/versions$")] def on_GET(self, request): - return (200, { - "versions": [ - # XXX: at some point we need to decide whether we need to include - # the previous version numbers, given we've defined r0.3.0 to be - # backwards compatible with r0.2.0. But need to check how - # conscientious we've been in compatibility, and decide whether the - # middle number is the major revision when at 0.X.Y (as opposed to - # X.Y.Z). And we need to decide whether it's fair to make clients - # parse the version string to figure out what's going on. - "r0.0.1", - "r0.1.0", - "r0.2.0", - "r0.3.0", - "r0.4.0", - "r0.5.0", - ], - # as per MSC1497: - "unstable_features": { - "m.lazy_load_members": True, - } - }) + return ( + 200, + { + "versions": [ + # XXX: at some point we need to decide whether we need to include + # the previous version numbers, given we've defined r0.3.0 to be + # backwards compatible with r0.2.0. But need to check how + # conscientious we've been in compatibility, and decide whether the + # middle number is the major revision when at 0.X.Y (as opposed to + # X.Y.Z). And we need to decide whether it's fair to make clients + # parse the version string to figure out what's going on. + "r0.0.1", + "r0.1.0", + "r0.2.0", + "r0.3.0", + "r0.4.0", + "r0.5.0", + ], + # as per MSC1497: + "unstable_features": {"m.lazy_load_members": True}, + }, + ) def register_servlets(http_server): |