diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 6dec862fec..bc629832d9 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -193,7 +193,7 @@ class Auth(object):
synapse.types.create_requester(user_id, app_service=app_service)
)
- access_token = get_access_token_from_request(
+ access_token = self.get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
)
@@ -239,7 +239,7 @@ class Auth(object):
@defer.inlineCallbacks
def _get_appservice_user_id(self, request):
app_service = self.store.get_app_service_by_token(
- get_access_token_from_request(
+ self.get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
)
)
@@ -513,7 +513,7 @@ class Auth(object):
def get_appservice_by_req(self, request):
try:
- token = get_access_token_from_request(
+ token = self.get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
)
service = self.store.get_app_service_by_token(token)
@@ -673,67 +673,67 @@ class Auth(object):
" edit its room list entry"
)
+ @staticmethod
+ def has_access_token(request):
+ """Checks if the request has an access_token.
-def has_access_token(request):
- """Checks if the request has an access_token.
+ Returns:
+ bool: False if no access_token was given, True otherwise.
+ """
+ query_params = request.args.get("access_token")
+ auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
+ return bool(query_params) or bool(auth_headers)
- Returns:
- bool: False if no access_token was given, True otherwise.
- """
- query_params = request.args.get("access_token")
- auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
- return bool(query_params) or bool(auth_headers)
-
-
-def get_access_token_from_request(request, token_not_found_http_status=401):
- """Extracts the access_token from the request.
-
- Args:
- request: The http request.
- token_not_found_http_status(int): The HTTP status code to set in the
- AuthError if the token isn't found. This is used in some of the
- legacy APIs to change the status code to 403 from the default of
- 401 since some of the old clients depended on auth errors returning
- 403.
- Returns:
- str: The access_token
- Raises:
- AuthError: If there isn't an access_token in the request.
- """
+ @staticmethod
+ def get_access_token_from_request(request, token_not_found_http_status=401):
+ """Extracts the access_token from the request.
- auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
- query_params = request.args.get(b"access_token")
- if auth_headers:
- # Try the get the access_token from a "Authorization: Bearer"
- # header
- if query_params is not None:
- raise AuthError(
- token_not_found_http_status,
- "Mixing Authorization headers and access_token query parameters.",
- errcode=Codes.MISSING_TOKEN,
- )
- if len(auth_headers) > 1:
- raise AuthError(
- token_not_found_http_status,
- "Too many Authorization headers.",
- errcode=Codes.MISSING_TOKEN,
- )
- parts = auth_headers[0].split(" ")
- if parts[0] == "Bearer" and len(parts) == 2:
- return parts[1]
+ Args:
+ request: The http request.
+ token_not_found_http_status(int): The HTTP status code to set in the
+ AuthError if the token isn't found. This is used in some of the
+ legacy APIs to change the status code to 403 from the default of
+ 401 since some of the old clients depended on auth errors returning
+ 403.
+ Returns:
+ str: The access_token
+ Raises:
+ AuthError: If there isn't an access_token in the request.
+ """
+
+ auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
+ query_params = request.args.get(b"access_token")
+ if auth_headers:
+ # Try the get the access_token from a "Authorization: Bearer"
+ # header
+ if query_params is not None:
+ raise AuthError(
+ token_not_found_http_status,
+ "Mixing Authorization headers and access_token query parameters.",
+ errcode=Codes.MISSING_TOKEN,
+ )
+ if len(auth_headers) > 1:
+ raise AuthError(
+ token_not_found_http_status,
+ "Too many Authorization headers.",
+ errcode=Codes.MISSING_TOKEN,
+ )
+ parts = auth_headers[0].split(" ")
+ if parts[0] == "Bearer" and len(parts) == 2:
+ return parts[1]
+ else:
+ raise AuthError(
+ token_not_found_http_status,
+ "Invalid Authorization header.",
+ errcode=Codes.MISSING_TOKEN,
+ )
else:
- raise AuthError(
- token_not_found_http_status,
- "Invalid Authorization header.",
- errcode=Codes.MISSING_TOKEN,
- )
- else:
- # Try to get the access_token from the query params.
- if not query_params:
- raise AuthError(
- token_not_found_http_status,
- "Missing access token.",
- errcode=Codes.MISSING_TOKEN
- )
+ # Try to get the access_token from the query params.
+ if not query_params:
+ raise AuthError(
+ token_not_found_http_status,
+ "Missing access token.",
+ errcode=Codes.MISSING_TOKEN
+ )
- return query_params[0]
+ return query_params[0]
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index f0c7a06718..c11798093d 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -23,7 +23,7 @@ from synapse.api.errors import Codes, SynapseError
from synapse.crypto.event_signing import check_event_content_hash
from synapse.events import FrozenEvent
from synapse.events.utils import prune_event
-from synapse.http.servlet import assert_params_in_request
+from synapse.http.servlet import assert_params_in_dict
from synapse.util import logcontext, unwrapFirstError
logger = logging.getLogger(__name__)
@@ -199,7 +199,7 @@ def event_from_pdu_json(pdu_json, outlier=False):
"""
# we could probably enforce a bunch of other fields here (room_id, sender,
# origin, etc etc)
- assert_params_in_request(pdu_json, ('event_id', 'type', 'depth'))
+ assert_params_in_dict(pdu_json, ('event_id', 'type', 'depth'))
depth = pdu_json['depth']
if not isinstance(depth, six.integer_types):
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 1abd45297b..828229f5c3 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -38,7 +38,7 @@ REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
# This is used to indicate we should only return rooms published to the main list.
-EMTPY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
+EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
class RoomListHandler(BaseHandler):
@@ -50,7 +50,7 @@ class RoomListHandler(BaseHandler):
def get_local_public_room_list(self, limit=None, since_token=None,
search_filter=None,
- network_tuple=EMTPY_THIRD_PARTY_ID,):
+ network_tuple=EMPTY_THIRD_PARTY_ID,):
"""Generate a local public room list.
There are multiple different lists: the main one plus one per third
@@ -87,7 +87,7 @@ class RoomListHandler(BaseHandler):
@defer.inlineCallbacks
def _get_public_room_list(self, limit=None, since_token=None,
search_filter=None,
- network_tuple=EMTPY_THIRD_PARTY_ID,):
+ network_tuple=EMPTY_THIRD_PARTY_ID,):
if since_token and since_token != "END":
since_token = RoomListNextBatch.from_token(since_token)
else:
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index cf6723563a..882816dc8f 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -206,7 +206,7 @@ def parse_json_object_from_request(request, allow_empty_body=False):
return content
-def assert_params_in_request(body, required):
+def assert_params_in_dict(body, required):
absent = []
for k in required:
if k not in body:
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 21e26f9c5e..5fd30a4c2c 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -20,7 +20,7 @@ from twisted.web.server import Request, Site
from synapse.http import redact_uri
from synapse.http.request_metrics import RequestMetrics
-from synapse.util.logcontext import LoggingContext, ContextResourceUsage
+from synapse.util.logcontext import ContextResourceUsage, LoggingContext
logger = logging.getLogger(__name__)
@@ -42,9 +42,10 @@ class SynapseRequest(Request):
which is handling the request, and returns a context manager.
"""
- def __init__(self, site, *args, **kw):
- Request.__init__(self, *args, **kw)
+ def __init__(self, site, channel, *args, **kw):
+ Request.__init__(self, channel, *args, **kw)
self.site = site
+ self._channel = channel
self.authenticated_entity = None
self.start_time = 0
diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py
index 7c01b438cb..00b1b3066e 100644
--- a/synapse/rest/client/transactions.py
+++ b/synapse/rest/client/transactions.py
@@ -17,38 +17,20 @@
to ensure idempotency when performing PUTs using the REST API."""
import logging
-from synapse.api.auth import get_access_token_from_request
from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
logger = logging.getLogger(__name__)
-
-def get_transaction_key(request):
- """A helper function which returns a transaction key that can be used
- with TransactionCache for idempotent requests.
-
- Idempotency is based on the returned key being the same for separate
- requests to the same endpoint. The key is formed from the HTTP request
- path and the access_token for the requesting user.
-
- Args:
- request (twisted.web.http.Request): The incoming request. Must
- contain an access_token.
- Returns:
- str: A transaction key
- """
- token = get_access_token_from_request(request)
- return request.path + "/" + token
-
-
CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins
class HttpTransactionCache(object):
- def __init__(self, clock):
- self.clock = clock
+ def __init__(self, hs):
+ self.hs = hs
+ self.auth = self.hs.get_auth()
+ self.clock = self.hs.get_clock()
self.transactions = {
# $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp)
}
@@ -56,6 +38,23 @@ class HttpTransactionCache(object):
# for at *LEAST* 30 mins, and at *MOST* 60 mins.
self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS)
+ def _get_transaction_key(self, request):
+ """A helper function which returns a transaction key that can be used
+ with TransactionCache for idempotent requests.
+
+ Idempotency is based on the returned key being the same for separate
+ requests to the same endpoint. The key is formed from the HTTP request
+ path and the access_token for the requesting user.
+
+ Args:
+ request (twisted.web.http.Request): The incoming request. Must
+ contain an access_token.
+ Returns:
+ str: A transaction key
+ """
+ token = self.auth.get_access_token_from_request(request)
+ return request.path + "/" + token
+
def fetch_or_execute_request(self, request, fn, *args, **kwargs):
"""A helper function for fetch_or_execute which extracts
a transaction key from the given request.
@@ -64,7 +63,7 @@ class HttpTransactionCache(object):
fetch_or_execute
"""
return self.fetch_or_execute(
- get_transaction_key(request), fn, *args, **kwargs
+ self._get_transaction_key(request), fn, *args, **kwargs
)
def fetch_or_execute(self, txn_key, fn, *args, **kwargs):
diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index 2b091d61a5..2dc50e582b 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/client/v1/admin.py
@@ -22,7 +22,12 @@ from twisted.internet import defer
from synapse.api.constants import Membership
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
-from synapse.http.servlet import parse_json_object_from_request
+from synapse.http.servlet import (
+ assert_params_in_dict,
+ parse_integer,
+ parse_json_object_from_request,
+ parse_string,
+)
from synapse.types import UserID, create_requester
from .base import ClientV1RestServlet, client_path_patterns
@@ -98,16 +103,8 @@ class PurgeMediaCacheRestServlet(ClientV1RestServlet):
if not is_admin:
raise AuthError(403, "You are not a server admin")
- before_ts = request.args.get("before_ts", None)
- if not before_ts:
- raise SynapseError(400, "Missing 'before_ts' arg")
-
- logger.info("before_ts: %r", before_ts[0])
-
- try:
- before_ts = int(before_ts[0])
- except Exception:
- raise SynapseError(400, "Invalid 'before_ts' arg")
+ before_ts = parse_integer(request, "before_ts", required=True)
+ logger.info("before_ts: %r", before_ts)
ret = yield self.media_repository.delete_old_remote_media(before_ts)
@@ -300,10 +297,8 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
raise AuthError(403, "You are not a server admin")
content = parse_json_object_from_request(request)
-
- new_room_user_id = content.get("new_room_user_id")
- if not new_room_user_id:
- raise SynapseError(400, "Please provide field `new_room_user_id`")
+ assert_params_in_dict(content, ["new_room_user_id"])
+ new_room_user_id = content["new_room_user_id"]
room_creator_requester = create_requester(new_room_user_id)
@@ -464,9 +459,8 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
raise AuthError(403, "You are not a server admin")
params = parse_json_object_from_request(request)
+ assert_params_in_dict(params, ["new_password"])
new_password = params['new_password']
- if not new_password:
- raise SynapseError(400, "Missing 'new_password' arg")
logger.info("new_password: %r", new_password)
@@ -514,12 +508,9 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet):
raise SynapseError(400, "Can only users a local user")
order = "name" # order by name in user table
- start = request.args.get("start")[0]
- limit = request.args.get("limit")[0]
- if not limit:
- raise SynapseError(400, "Missing 'limit' arg")
- if not start:
- raise SynapseError(400, "Missing 'start' arg")
+ start = parse_integer(request, "start", required=True)
+ limit = parse_integer(request, "limit", required=True)
+
logger.info("limit: %s, start: %s", limit, start)
ret = yield self.handlers.admin_handler.get_users_paginate(
@@ -551,12 +542,9 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet):
order = "name" # order by name in user table
params = parse_json_object_from_request(request)
+ assert_params_in_dict(params, ["limit", "start"])
limit = params['limit']
start = params['start']
- if not limit:
- raise SynapseError(400, "Missing 'limit' arg")
- if not start:
- raise SynapseError(400, "Missing 'start' arg")
logger.info("limit: %s, start: %s", limit, start)
ret = yield self.handlers.admin_handler.get_users_paginate(
@@ -604,10 +592,7 @@ class SearchUsersRestServlet(ClientV1RestServlet):
if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only users a local user")
- term = request.args.get("term")[0]
- if not term:
- raise SynapseError(400, "Missing 'term' arg")
-
+ term = parse_string(request, "term", required=True)
logger.info("term: %s ", term)
ret = yield self.handlers.admin_handler.search_users(
diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py
index dde02328c3..c77d7aba68 100644
--- a/synapse/rest/client/v1/base.py
+++ b/synapse/rest/client/v1/base.py
@@ -62,4 +62,4 @@ class ClientV1RestServlet(RestServlet):
self.hs = hs
self.builder_factory = hs.get_event_builder_factory()
self.auth = hs.get_auth()
- self.txns = HttpTransactionCache(hs.get_clock())
+ self.txns = HttpTransactionCache(hs)
diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
index 4fdbb83815..69dcd618cb 100644
--- a/synapse/rest/client/v1/directory.py
+++ b/synapse/rest/client/v1/directory.py
@@ -52,15 +52,14 @@ class ClientDirectoryServer(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, room_alias):
+ room_alias = RoomAlias.from_string(room_alias)
+
content = parse_json_object_from_request(request)
if "room_id" not in content:
- raise SynapseError(400, "Missing room_id key",
+ raise SynapseError(400, 'Missing params: ["room_id"]',
errcode=Codes.BAD_JSON)
logger.debug("Got content: %s", content)
-
- room_alias = RoomAlias.from_string(room_alias)
-
logger.debug("Got room name: %s", room_alias.to_string())
room_id = content["room_id"]
diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py
index fbe8cb2023..fd5f85b53e 100644
--- a/synapse/rest/client/v1/initial_sync.py
+++ b/synapse/rest/client/v1/initial_sync.py
@@ -15,6 +15,7 @@
from twisted.internet import defer
+from synapse.http.servlet import parse_boolean
from synapse.streams.config import PaginationConfig
from .base import ClientV1RestServlet, client_path_patterns
@@ -33,7 +34,7 @@ class InitialSyncRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request)
as_client_event = "raw" not in request.args
pagination_config = PaginationConfig.from_request(request)
- include_archived = request.args.get("archived", None) == ["true"]
+ include_archived = parse_boolean(request, "archived", default=False)
content = yield self.initial_sync_handler.snapshot_all_rooms(
user_id=requester.user.to_string(),
pagin_config=pagination_config,
diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py
index 05a8ecfcd8..430c692336 100644
--- a/synapse/rest/client/v1/logout.py
+++ b/synapse/rest/client/v1/logout.py
@@ -17,7 +17,6 @@ import logging
from twisted.internet import defer
-from synapse.api.auth import get_access_token_from_request
from synapse.api.errors import AuthError
from .base import ClientV1RestServlet, client_path_patterns
@@ -51,7 +50,7 @@ class LogoutRestServlet(ClientV1RestServlet):
if requester.device_id is None:
# the acccess token wasn't associated with a device.
# Just delete the access token
- access_token = get_access_token_from_request(request)
+ access_token = self._auth.get_access_token_from_request(request)
yield self._auth_handler.delete_access_token(access_token)
else:
yield self._device_handler.delete_device(
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index 0df7ce570f..6e95d9bec2 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -21,7 +21,7 @@ from synapse.api.errors import (
SynapseError,
UnrecognizedRequestError,
)
-from synapse.http.servlet import parse_json_value_from_request
+from synapse.http.servlet import parse_json_value_from_request, parse_string
from synapse.push.baserules import BASE_RULE_IDS
from synapse.push.clientformat import format_push_rules_for_user
from synapse.push.rulekinds import PRIORITY_CLASS_MAP
@@ -75,13 +75,13 @@ class PushRuleRestServlet(ClientV1RestServlet):
except InvalidRuleException as e:
raise SynapseError(400, e.message)
- before = request.args.get("before", None)
+ before = parse_string(request, "before")
if before:
- before = _namespaced_rule_id(spec, before[0])
+ before = _namespaced_rule_id(spec, before)
- after = request.args.get("after", None)
+ after = parse_string(request, "after")
if after:
- after = _namespaced_rule_id(spec, after[0])
+ after = _namespaced_rule_id(spec, after)
try:
yield self.store.add_push_rule(
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 1581f88db5..182a68b1e2 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -21,6 +21,7 @@ from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.http.server import finish_request
from synapse.http.servlet import (
RestServlet,
+ assert_params_in_dict,
parse_json_object_from_request,
parse_string,
)
@@ -91,15 +92,11 @@ class PushersSetRestServlet(ClientV1RestServlet):
)
defer.returnValue((200, {}))
- reqd = ['kind', 'app_id', 'app_display_name',
- 'device_display_name', 'pushkey', 'lang', 'data']
- missing = []
- for i in reqd:
- if i not in content:
- missing.append(i)
- if len(missing):
- raise SynapseError(400, "Missing parameters: " + ','.join(missing),
- errcode=Codes.MISSING_PARAM)
+ assert_params_in_dict(
+ content,
+ ['kind', 'app_id', 'app_display_name',
+ 'device_display_name', 'pushkey', 'lang', 'data']
+ )
logger.debug("set pushkey %s to kind %s", content['pushkey'], content['kind'])
logger.debug("Got pushers request with body: %r", content)
@@ -148,7 +145,7 @@ class PushersRemoveRestServlet(RestServlet):
SUCCESS_HTML = "<html><body>You have been unsubscribed</body><html>"
def __init__(self, hs):
- super(RestServlet, self).__init__()
+ super(PushersRemoveRestServlet, self).__init__()
self.hs = hs
self.notifier = hs.get_notifier()
self.auth = hs.get_auth()
diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py
index 3ce5f8b726..25a143af8d 100644
--- a/synapse/rest/client/v1/register.py
+++ b/synapse/rest/client/v1/register.py
@@ -18,15 +18,12 @@ import hmac
import logging
from hashlib import sha1
-from six import string_types
-
from twisted.internet import defer
import synapse.util.stringutils as stringutils
-from synapse.api.auth import get_access_token_from_request
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError
-from synapse.http.servlet import parse_json_object_from_request
+from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
from synapse.types import create_requester
from .base import ClientV1RestServlet, client_path_patterns
@@ -67,6 +64,7 @@ class RegisterRestServlet(ClientV1RestServlet):
# TODO: persistent storage
self.sessions = {}
self.enable_registration = hs.config.enable_registration
+ self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self.handlers = hs.get_handlers()
@@ -124,8 +122,7 @@ class RegisterRestServlet(ClientV1RestServlet):
session = (register_json["session"]
if "session" in register_json else None)
login_type = None
- if "type" not in register_json:
- raise SynapseError(400, "Missing 'type' key.")
+ assert_params_in_dict(register_json, ["type"])
try:
login_type = register_json["type"]
@@ -310,11 +307,9 @@ class RegisterRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def _do_app_service(self, request, register_json, session):
- as_token = get_access_token_from_request(request)
-
- if "user" not in register_json:
- raise SynapseError(400, "Expected 'user' key.")
+ as_token = self.auth.get_access_token_from_request(request)
+ assert_params_in_dict(register_json, ["user"])
user_localpart = register_json["user"].encode("utf-8")
handler = self.handlers.registration_handler
@@ -331,12 +326,7 @@ class RegisterRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def _do_shared_secret(self, request, register_json, session):
- if not isinstance(register_json.get("mac", None), string_types):
- raise SynapseError(400, "Expected mac.")
- if not isinstance(register_json.get("user", None), string_types):
- raise SynapseError(400, "Expected 'user' key.")
- if not isinstance(register_json.get("password", None), string_types):
- raise SynapseError(400, "Expected 'password' key.")
+ assert_params_in_dict(register_json, ["mac", "user", "password"])
if not self.hs.config.registration_shared_secret:
raise SynapseError(400, "Shared secret registration is not enabled")
@@ -400,7 +390,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
def on_POST(self, request):
user_json = parse_json_object_from_request(request)
- access_token = get_access_token_from_request(request)
+ access_token = self.auth.get_access_token_from_request(request)
app_service = self.store.get_app_service_by_token(
access_token
)
@@ -419,11 +409,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def _do_create(self, requester, user_json):
- if "localpart" not in user_json:
- raise SynapseError(400, "Expected 'localpart' key.")
-
- if "displayname" not in user_json:
- raise SynapseError(400, "Expected 'displayname' key.")
+ assert_params_in_dict(user_json, ["localpart", "displayname"])
localpart = user_json["localpart"].encode("utf-8")
displayname = user_json["displayname"].encode("utf-8")
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 2470db52ba..3d62447854 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -28,6 +28,7 @@ from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.api.filtering import Filter
from synapse.events.utils import format_event_for_client_v2, serialize_event
from synapse.http.servlet import (
+ assert_params_in_dict,
parse_integer,
parse_json_object_from_request,
parse_string,
@@ -435,9 +436,9 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
request, default_limit=10,
)
as_client_event = "raw" not in request.args
- filter_bytes = request.args.get("filter", None)
+ filter_bytes = parse_string(request, "filter")
if filter_bytes:
- filter_json = urlparse.unquote(filter_bytes[-1]).decode("UTF-8")
+ filter_json = urlparse.unquote(filter_bytes).decode("UTF-8")
event_filter = Filter(json.loads(filter_json))
else:
event_filter = None
@@ -530,7 +531,7 @@ class RoomEventContextServlet(ClientV1RestServlet):
def on_GET(self, request, room_id, event_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
- limit = int(request.args.get("limit", [10])[0])
+ limit = parse_integer(request, "limit", default=10)
results = yield self.handlers.room_context_handler.get_event_context(
requester.user,
@@ -636,8 +637,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
target = requester.user
if membership_action in ["invite", "ban", "unban", "kick"]:
- if "user_id" not in content:
- raise SynapseError(400, "Missing user_id key.")
+ assert_params_in_dict(content, ["user_id"])
target = UserID.from_string(content["user_id"])
event_content = None
@@ -764,7 +764,7 @@ class SearchRestServlet(ClientV1RestServlet):
content = parse_json_object_from_request(request)
- batch = request.args.get("next_batch", [None])[0]
+ batch = parse_string(request, "next_batch")
results = yield self.handlers.search_handler.search(
requester.user,
content,
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 528c1f43f9..eeae466d82 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -20,12 +20,11 @@ from six.moves import http_client
from twisted.internet import defer
-from synapse.api.auth import has_access_token
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import (
RestServlet,
- assert_params_in_request,
+ assert_params_in_dict,
parse_json_object_from_request,
)
from synapse.util.msisdn import phone_number_to_msisdn
@@ -48,7 +47,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
def on_POST(self, request):
body = parse_json_object_from_request(request)
- assert_params_in_request(body, [
+ assert_params_in_dict(body, [
'id_server', 'client_secret', 'email', 'send_attempt'
])
@@ -81,7 +80,7 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):
def on_POST(self, request):
body = parse_json_object_from_request(request)
- assert_params_in_request(body, [
+ assert_params_in_dict(body, [
'id_server', 'client_secret',
'country', 'phone_number', 'send_attempt',
])
@@ -130,7 +129,7 @@ class PasswordRestServlet(RestServlet):
#
# In the second case, we require a password to confirm their identity.
- if has_access_token(request):
+ if self.auth.has_access_token(request):
requester = yield self.auth.get_user_by_req(request)
params = yield self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request),
@@ -160,11 +159,10 @@ class PasswordRestServlet(RestServlet):
raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
user_id = threepid_user_id
else:
- logger.error("Auth succeeded but no known type!", result.keys())
+ logger.error("Auth succeeded but no known type! %r", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN)
- if 'new_password' not in params:
- raise SynapseError(400, "", Codes.MISSING_PARAM)
+ assert_params_in_dict(params, ["new_password"])
new_password = params['new_password']
yield self._set_password_handler.set_password(
@@ -229,15 +227,10 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
-
- required = ['id_server', 'client_secret', 'email', 'send_attempt']
- absent = []
- for k in required:
- if k not in body:
- absent.append(k)
-
- if absent:
- raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+ assert_params_in_dict(
+ body,
+ ['id_server', 'client_secret', 'email', 'send_attempt'],
+ )
if not check_3pid_allowed(self.hs, "email", body['email']):
raise SynapseError(
@@ -267,18 +260,10 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
-
- required = [
+ assert_params_in_dict(body, [
'id_server', 'client_secret',
'country', 'phone_number', 'send_attempt',
- ]
- absent = []
- for k in required:
- if k not in body:
- absent.append(k)
-
- if absent:
- raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+ ])
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
@@ -373,15 +358,7 @@ class ThreepidDeleteRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
-
- required = ['medium', 'address']
- absent = []
- for k in required:
- if k not in body:
- absent.append(k)
-
- if absent:
- raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+ assert_params_in_dict(body, ['medium', 'address'])
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index 09f6a8efe3..9b75bb1377 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -18,14 +18,18 @@ import logging
from twisted.internet import defer
from synapse.api import errors
-from synapse.http import servlet
+from synapse.http.servlet import (
+ RestServlet,
+ assert_params_in_dict,
+ parse_json_object_from_request,
+)
from ._base import client_v2_patterns, interactive_auth_handler
logger = logging.getLogger(__name__)
-class DevicesRestServlet(servlet.RestServlet):
+class DevicesRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/devices$", v2_alpha=False)
def __init__(self, hs):
@@ -47,7 +51,7 @@ class DevicesRestServlet(servlet.RestServlet):
defer.returnValue((200, {"devices": devices}))
-class DeleteDevicesRestServlet(servlet.RestServlet):
+class DeleteDevicesRestServlet(RestServlet):
"""
API for bulk deletion of devices. Accepts a JSON object with a devices
key which lists the device_ids to delete. Requires user interactive auth.
@@ -67,19 +71,17 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
requester = yield self.auth.get_user_by_req(request)
try:
- body = servlet.parse_json_object_from_request(request)
+ body = parse_json_object_from_request(request)
except errors.SynapseError as e:
if e.errcode == errors.Codes.NOT_JSON:
- # deal with older clients which didn't pass a J*DELETESON dict
+ # DELETE
+ # deal with older clients which didn't pass a JSON dict
# the same as those that pass an empty dict
body = {}
else:
raise e
- if 'devices' not in body:
- raise errors.SynapseError(
- 400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM
- )
+ assert_params_in_dict(body, ["devices"])
yield self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request),
@@ -92,7 +94,7 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
defer.returnValue((200, {}))
-class DeviceRestServlet(servlet.RestServlet):
+class DeviceRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", v2_alpha=False)
def __init__(self, hs):
@@ -121,7 +123,7 @@ class DeviceRestServlet(servlet.RestServlet):
requester = yield self.auth.get_user_by_req(request)
try:
- body = servlet.parse_json_object_from_request(request)
+ body = parse_json_object_from_request(request)
except errors.SynapseError as e:
if e.errcode == errors.Codes.NOT_JSON:
@@ -144,7 +146,7 @@ class DeviceRestServlet(servlet.RestServlet):
def on_PUT(self, request, device_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
- body = servlet.parse_json_object_from_request(request)
+ body = parse_json_object_from_request(request)
yield self.device_handler.update_device(
requester.user.to_string(),
device_id,
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 896650d5a5..d6cf915d86 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -24,12 +24,11 @@ from twisted.internet import defer
import synapse
import synapse.types
-from synapse.api.auth import get_access_token_from_request, has_access_token
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError
from synapse.http.servlet import (
RestServlet,
- assert_params_in_request,
+ assert_params_in_dict,
parse_json_object_from_request,
parse_string,
)
@@ -69,7 +68,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
def on_POST(self, request):
body = parse_json_object_from_request(request)
- assert_params_in_request(body, [
+ assert_params_in_dict(body, [
'id_server', 'client_secret', 'email', 'send_attempt'
])
@@ -105,7 +104,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
def on_POST(self, request):
body = parse_json_object_from_request(request)
- assert_params_in_request(body, [
+ assert_params_in_dict(body, [
'id_server', 'client_secret',
'country', 'phone_number',
'send_attempt',
@@ -224,7 +223,7 @@ class RegisterRestServlet(RestServlet):
desired_username = body['username']
appservice = None
- if has_access_token(request):
+ if self.auth.has_access_token(request):
appservice = yield self.auth.get_appservice_by_req(request)
# fork off as soon as possible for ASes and shared secret auth which
@@ -242,7 +241,7 @@ class RegisterRestServlet(RestServlet):
# because the IRC bridges rely on being able to register stupid
# IDs.
- access_token = get_access_token_from_request(request)
+ access_token = self.auth.get_access_token_from_request(request)
if isinstance(desired_username, string_types):
result = yield self._do_appservice_registration(
@@ -387,9 +386,7 @@ class RegisterRestServlet(RestServlet):
add_msisdn = False
else:
# NB: This may be from the auth handler and NOT from the POST
- if 'password' not in params:
- raise SynapseError(400, "Missing password.",
- Codes.MISSING_PARAM)
+ assert_params_in_dict(params, ["password"])
desired_username = params.get("username", None)
new_password = params.get("password", None)
@@ -566,11 +563,14 @@ class RegisterRestServlet(RestServlet):
Returns:
defer.Deferred:
"""
- reqd = ('medium', 'address', 'validated_at')
- if any(x not in threepid for x in reqd):
- # This will only happen if the ID server returns a malformed response
- logger.info("Can't add incomplete 3pid")
- defer.returnValue()
+ try:
+ assert_params_in_dict(threepid, ['medium', 'address', 'validated_at'])
+ except SynapseError as ex:
+ if ex.errcode == Codes.MISSING_PARAM:
+ # This will only happen if the ID server returns a malformed response
+ logger.info("Can't add incomplete 3pid")
+ defer.returnValue(None)
+ raise
yield self.auth_handler.add_threepid(
user_id,
@@ -643,7 +643,7 @@ class RegisterRestServlet(RestServlet):
@defer.inlineCallbacks
def _do_guest_registration(self, params):
if not self.hs.config.allow_guest_access:
- defer.returnValue((403, "Guest access is disabled"))
+ raise SynapseError(403, "Guest access is disabled")
user_id, _ = yield self.registration_handler.register(
generate_token=False,
make_guest=True
diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py
index 0cc2a71c3b..95d2a71ec2 100644
--- a/synapse/rest/client/v2_alpha/report_event.py
+++ b/synapse/rest/client/v2_alpha/report_event.py
@@ -23,7 +23,7 @@ from twisted.internet import defer
from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import (
RestServlet,
- assert_params_in_request,
+ assert_params_in_dict,
parse_json_object_from_request,
)
@@ -50,7 +50,7 @@ class ReportEventRestServlet(RestServlet):
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
- assert_params_in_request(body, ("reason", "score"))
+ assert_params_in_dict(body, ("reason", "score"))
if not isinstance(body["reason"], string_types):
raise SynapseError(
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index 90bdb1db15..a9e9a47a0b 100644
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -40,7 +40,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
super(SendToDeviceRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
- self.txns = HttpTransactionCache(hs.get_clock())
+ self.txns = HttpTransactionCache(hs)
self.device_message_handler = hs.get_device_message_handler()
def on_PUT(self, request, message_type, txn_id):
diff --git a/synapse/rest/media/v1/identicon_resource.py b/synapse/rest/media/v1/identicon_resource.py
index a2e391415f..bdbd8d50dd 100644
--- a/synapse/rest/media/v1/identicon_resource.py
+++ b/synapse/rest/media/v1/identicon_resource.py
@@ -16,6 +16,8 @@ from pydenticon import Generator
from twisted.web.resource import Resource
+from synapse.http.servlet import parse_integer
+
FOREGROUND = [
"rgb(45,79,255)",
"rgb(254,180,44)",
@@ -56,8 +58,8 @@ class IdenticonResource(Resource):
def render_GET(self, request):
name = "/".join(request.postpath)
- width = int(request.args.get("width", [96])[0])
- height = int(request.args.get("height", [96])[0])
+ width = parse_integer(request, "width", default=96)
+ height = parse_integer(request, "height", default=96)
identicon_bytes = self.generate_identicon(name, width, height)
request.setHeader(b"Content-Type", b"image/png")
request.setHeader(
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 4e3a18ce08..b70b15c4c2 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -40,6 +40,7 @@ from synapse.http.server import (
respond_with_json_bytes,
wrap_json_request_handler,
)
+from synapse.http.servlet import parse_integer, parse_string
from synapse.util.async import ObservableDeferred
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
@@ -96,9 +97,9 @@ class PreviewUrlResource(Resource):
# XXX: if get_user_by_req fails, what should we do in an async render?
requester = yield self.auth.get_user_by_req(request)
- url = request.args.get("url")[0]
+ url = parse_string(request, "url")
if "ts" in request.args:
- ts = int(request.args.get("ts")[0])
+ ts = parse_integer(request, "ts")
else:
ts = self.clock.time_msec()
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 1a98120e1d..9b22d204a6 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -21,6 +21,7 @@ from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import SynapseError
from synapse.http.server import respond_with_json, wrap_json_request_handler
+from synapse.http.servlet import parse_string
logger = logging.getLogger(__name__)
@@ -65,10 +66,10 @@ class UploadResource(Resource):
code=413,
)
- upload_name = request.args.get("filename", None)
+ upload_name = parse_string(request, "filename")
if upload_name:
try:
- upload_name = upload_name[0].decode('UTF-8')
+ upload_name = upload_name.decode('UTF-8')
except UnicodeDecodeError:
raise SynapseError(
msg="Invalid UTF-8 filename parameter: %r" % (upload_name),
diff --git a/synapse/streams/config.py b/synapse/streams/config.py
index 46ccbbda7d..451e4fa441 100644
--- a/synapse/streams/config.py
+++ b/synapse/streams/config.py
@@ -16,6 +16,7 @@
import logging
from synapse.api.errors import SynapseError
+from synapse.http.servlet import parse_integer, parse_string
from synapse.types import StreamToken
logger = logging.getLogger(__name__)
@@ -56,23 +57,10 @@ class PaginationConfig(object):
@classmethod
def from_request(cls, request, raise_invalid_params=True,
default_limit=None):
- def get_param(name, default=None):
- lst = request.args.get(name, [])
- if len(lst) > 1:
- raise SynapseError(
- 400, "%s must be specified only once" % (name,)
- )
- elif len(lst) == 1:
- return lst[0]
- else:
- return default
-
- direction = get_param("dir", 'f')
- if direction not in ['f', 'b']:
- raise SynapseError(400, "'dir' parameter is invalid.")
-
- from_tok = get_param("from")
- to_tok = get_param("to")
+ direction = parse_string(request, "dir", default='f', allowed_values=['f', 'b'])
+
+ from_tok = parse_string(request, "from")
+ to_tok = parse_string(request, "to")
try:
if from_tok == "END":
@@ -88,12 +76,10 @@ class PaginationConfig(object):
except Exception:
raise SynapseError(400, "'to' paramater is invalid")
- limit = get_param("limit", None)
- if limit is not None and not limit.isdigit():
- raise SynapseError(400, "'limit' parameter must be an integer.")
+ limit = parse_integer(request, "limit", default=default_limit)
- if limit is None:
- limit = default_limit
+ if limit and limit < 0:
+ raise SynapseError(400, "Limit must be 0 or above")
try:
return PaginationConfig(from_tok, to_tok, direction, limit)
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index a1f8ff8f10..f2bde74dc5 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -80,12 +80,7 @@ class StreamChangeCache(object):
)
}
- # we need to include entities which we don't know about, as well as
- # those which are known to have changed since the stream pos.
- result = {
- e for e in entities
- if e in changed_entities or e not in self._entity_to_key
- }
+ result = changed_entities.intersection(entities)
self.metrics.inc_hits()
else:
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index 5ac33b2132..7deb38f2a7 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -92,13 +92,22 @@ class _PerHostRatelimiter(object):
self.window_size = window_size
self.sleep_limit = sleep_limit
- self.sleep_msec = sleep_msec
+ self.sleep_sec = sleep_msec / 1000.0
self.reject_limit = reject_limit
self.concurrent_requests = concurrent_requests
+ # request_id objects for requests which have been slept
self.sleeping_requests = set()
+
+ # map from request_id object to Deferred for requests which are ready
+ # for processing but have been queued
self.ready_request_queue = collections.OrderedDict()
+
+ # request id objects for requests which are in progress
self.current_processing = set()
+
+ # times at which we have recently (within the last window_size ms)
+ # received requests.
self.request_times = []
@contextlib.contextmanager
@@ -117,11 +126,15 @@ class _PerHostRatelimiter(object):
def _on_enter(self, request_id):
time_now = self.clock.time_msec()
+
+ # remove any entries from request_times which aren't within the window
self.request_times[:] = [
r for r in self.request_times
if time_now - r < self.window_size
]
+ # reject the request if we already have too many queued up (either
+ # sleeping or in the ready queue).
queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
if queue_size > self.reject_limit:
raise LimitExceededError(
@@ -134,9 +147,13 @@ class _PerHostRatelimiter(object):
def queue_request():
if len(self.current_processing) > self.concurrent_requests:
- logger.debug("Ratelimit [%s]: Queue req", id(request_id))
queue_defer = defer.Deferred()
self.ready_request_queue[request_id] = queue_defer
+ logger.info(
+ "Ratelimiter: queueing request (queue now %i items)",
+ len(self.ready_request_queue),
+ )
+
return queue_defer
else:
return defer.succeed(None)
@@ -148,10 +165,9 @@ class _PerHostRatelimiter(object):
if len(self.request_times) > self.sleep_limit:
logger.debug(
- "Ratelimit [%s]: sleeping req",
- id(request_id),
+ "Ratelimiter: sleeping request for %f sec", self.sleep_sec,
)
- ret_defer = run_in_background(self.clock.sleep, self.sleep_msec / 1000.0)
+ ret_defer = run_in_background(self.clock.sleep, self.sleep_sec)
self.sleeping_requests.add(request_id)
@@ -200,11 +216,8 @@ class _PerHostRatelimiter(object):
)
self.current_processing.discard(request_id)
try:
- request_id, deferred = self.ready_request_queue.popitem()
-
- # XXX: why do we do the following? the on_start callback above will
- # do it for us.
- self.current_processing.add(request_id)
+ # start processing the next item on the queue.
+ _, deferred = self.ready_request_queue.popitem(last=False)
with PreserveLoggingContext():
deferred.callback(None)
|