diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py
index fceca2edeb..00b1b3066e 100644
--- a/synapse/rest/client/transactions.py
+++ b/synapse/rest/client/transactions.py
@@ -17,37 +17,20 @@
to ensure idempotency when performing PUTs using the REST API."""
import logging
-from synapse.api.auth import get_access_token_from_request
from synapse.util.async import ObservableDeferred
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
logger = logging.getLogger(__name__)
-
-def get_transaction_key(request):
- """A helper function which returns a transaction key that can be used
- with TransactionCache for idempotent requests.
-
- Idempotency is based on the returned key being the same for separate
- requests to the same endpoint. The key is formed from the HTTP request
- path and the access_token for the requesting user.
-
- Args:
- request (twisted.web.http.Request): The incoming request. Must
- contain an access_token.
- Returns:
- str: A transaction key
- """
- token = get_access_token_from_request(request)
- return request.path + "/" + token
-
-
CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins
class HttpTransactionCache(object):
- def __init__(self, clock):
- self.clock = clock
+ def __init__(self, hs):
+ self.hs = hs
+ self.auth = self.hs.get_auth()
+ self.clock = self.hs.get_clock()
self.transactions = {
# $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp)
}
@@ -55,6 +38,23 @@ class HttpTransactionCache(object):
# for at *LEAST* 30 mins, and at *MOST* 60 mins.
self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS)
+ def _get_transaction_key(self, request):
+ """A helper function which returns a transaction key that can be used
+ with TransactionCache for idempotent requests.
+
+ Idempotency is based on the returned key being the same for separate
+ requests to the same endpoint. The key is formed from the HTTP request
+ path and the access_token for the requesting user.
+
+ Args:
+ request (twisted.web.http.Request): The incoming request. Must
+ contain an access_token.
+ Returns:
+ str: A transaction key
+ """
+ token = self.auth.get_access_token_from_request(request)
+ return request.path + "/" + token
+
def fetch_or_execute_request(self, request, fn, *args, **kwargs):
"""A helper function for fetch_or_execute which extracts
a transaction key from the given request.
@@ -63,7 +63,7 @@ class HttpTransactionCache(object):
fetch_or_execute
"""
return self.fetch_or_execute(
- get_transaction_key(request), fn, *args, **kwargs
+ self._get_transaction_key(request), fn, *args, **kwargs
)
def fetch_or_execute(self, txn_key, fn, *args, **kwargs):
@@ -80,31 +80,30 @@ class HttpTransactionCache(object):
Returns:
Deferred which resolves to a tuple of (response_code, response_dict).
"""
- try:
- return self.transactions[txn_key][0].observe()
- except (KeyError, IndexError):
- pass # execute the function instead.
-
- deferred = fn(*args, **kwargs)
-
- # if the request fails with a Twisted failure, remove it
- # from the transaction map. This is done to ensure that we don't
- # cache transient errors like rate-limiting errors, etc.
- def remove_from_map(err):
- self.transactions.pop(txn_key, None)
- return err
- deferred.addErrback(remove_from_map)
-
- # We don't add any other errbacks to the raw deferred, so we ask
- # ObservableDeferred to swallow the error. This is fine as the error will
- # still be reported to the observers.
- observable = ObservableDeferred(deferred, consumeErrors=True)
- self.transactions[txn_key] = (observable, self.clock.time_msec())
- return observable.observe()
+ if txn_key in self.transactions:
+ observable = self.transactions[txn_key][0]
+ else:
+ # execute the function instead.
+ deferred = run_in_background(fn, *args, **kwargs)
+
+ observable = ObservableDeferred(deferred)
+ self.transactions[txn_key] = (observable, self.clock.time_msec())
+
+ # if the request fails with an exception, remove it
+ # from the transaction map. This is done to ensure that we don't
+ # cache transient errors like rate-limiting errors, etc.
+ def remove_from_map(err):
+ self.transactions.pop(txn_key, None)
+ # we deliberately do not propagate the error any further, as we
+ # expect the observers to have reported it.
+
+ deferred.addErrback(remove_from_map)
+
+ return make_deferred_yieldable(observable.observe())
def _cleanup(self):
now = self.clock.time_msec()
- for key in self.transactions.keys():
+ for key in list(self.transactions):
ts = self.transactions[key][1]
if now > (ts + CLEANUP_PERIOD_MS): # after cleanup period
del self.transactions[key]
diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index 7d786e8de3..99f6c6e3c3 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/client/v1/admin.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,17 +14,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import hashlib
+import hmac
+import logging
+
+from six.moves import http_client
+
from twisted.internet import defer
from synapse.api.constants import Membership
-from synapse.api.errors import AuthError, SynapseError
+from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
+from synapse.http.servlet import (
+ assert_params_in_dict,
+ parse_integer,
+ parse_json_object_from_request,
+ parse_string,
+)
from synapse.types import UserID, create_requester
-from synapse.http.servlet import parse_json_object_from_request
from .base import ClientV1RestServlet, client_path_patterns
-import logging
-
logger = logging.getLogger(__name__)
@@ -55,6 +65,125 @@ class UsersRestServlet(ClientV1RestServlet):
defer.returnValue((200, ret))
+class UserRegisterServlet(ClientV1RestServlet):
+ """
+ Attributes:
+ NONCE_TIMEOUT (int): Seconds until a generated nonce won't be accepted
+ nonces (dict[str, int]): The nonces that we will accept. A dict of
+ nonce to the time it was generated, in int seconds.
+ """
+ PATTERNS = client_path_patterns("/admin/register")
+ NONCE_TIMEOUT = 60
+
+ def __init__(self, hs):
+ super(UserRegisterServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+ self.reactor = hs.get_reactor()
+ self.nonces = {}
+ self.hs = hs
+
+ def _clear_old_nonces(self):
+ """
+ Clear out old nonces that are older than NONCE_TIMEOUT.
+ """
+ now = int(self.reactor.seconds())
+
+ for k, v in list(self.nonces.items()):
+ if now - v > self.NONCE_TIMEOUT:
+ del self.nonces[k]
+
+ def on_GET(self, request):
+ """
+ Generate a new nonce.
+ """
+ self._clear_old_nonces()
+
+ nonce = self.hs.get_secrets().token_hex(64)
+ self.nonces[nonce] = int(self.reactor.seconds())
+ return (200, {"nonce": nonce.encode('ascii')})
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ self._clear_old_nonces()
+
+ if not self.hs.config.registration_shared_secret:
+ raise SynapseError(400, "Shared secret registration is not enabled")
+
+ body = parse_json_object_from_request(request)
+
+ if "nonce" not in body:
+ raise SynapseError(
+ 400, "nonce must be specified", errcode=Codes.BAD_JSON,
+ )
+
+ nonce = body["nonce"]
+
+ if nonce not in self.nonces:
+ raise SynapseError(
+ 400, "unrecognised nonce",
+ )
+
+ # Delete the nonce, so it can't be reused, even if it's invalid
+ del self.nonces[nonce]
+
+ if "username" not in body:
+ raise SynapseError(
+ 400, "username must be specified", errcode=Codes.BAD_JSON,
+ )
+ else:
+ if (not isinstance(body['username'], str) or len(body['username']) > 512):
+ raise SynapseError(400, "Invalid username")
+
+ username = body["username"].encode("utf-8")
+ if b"\x00" in username:
+ raise SynapseError(400, "Invalid username")
+
+ if "password" not in body:
+ raise SynapseError(
+ 400, "password must be specified", errcode=Codes.BAD_JSON,
+ )
+ else:
+ if (not isinstance(body['password'], str) or len(body['password']) > 512):
+ raise SynapseError(400, "Invalid password")
+
+ password = body["password"].encode("utf-8")
+ if b"\x00" in password:
+ raise SynapseError(400, "Invalid password")
+
+ admin = body.get("admin", None)
+ got_mac = body["mac"]
+
+ want_mac = hmac.new(
+ key=self.hs.config.registration_shared_secret.encode(),
+ digestmod=hashlib.sha1,
+ )
+ want_mac.update(nonce)
+ want_mac.update(b"\x00")
+ want_mac.update(username)
+ want_mac.update(b"\x00")
+ want_mac.update(password)
+ want_mac.update(b"\x00")
+ want_mac.update(b"admin" if admin else b"notadmin")
+ want_mac = want_mac.hexdigest()
+
+ if not hmac.compare_digest(want_mac, got_mac):
+ raise SynapseError(
+ 403, "HMAC incorrect",
+ )
+
+ # Reuse the parts of RegisterRestServlet to reduce code duplication
+ from synapse.rest.client.v2_alpha.register import RegisterRestServlet
+ register = RegisterRestServlet(self.hs)
+
+ (user_id, _) = yield register.registration_handler.register(
+ localpart=username.lower(), password=password, admin=bool(admin),
+ generate_token=False,
+ )
+
+ result = yield register._create_registration_details(user_id, body)
+ defer.returnValue((200, result))
+
+
class WhoisRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)")
@@ -95,16 +224,8 @@ class PurgeMediaCacheRestServlet(ClientV1RestServlet):
if not is_admin:
raise AuthError(403, "You are not a server admin")
- before_ts = request.args.get("before_ts", None)
- if not before_ts:
- raise SynapseError(400, "Missing 'before_ts' arg")
-
- logger.info("before_ts: %r", before_ts[0])
-
- try:
- before_ts = int(before_ts[0])
- except Exception:
- raise SynapseError(400, "Invalid 'before_ts' arg")
+ before_ts = parse_integer(request, "before_ts", required=True)
+ logger.info("before_ts: %r", before_ts)
ret = yield self.media_repository.delete_old_remote_media(before_ts)
@@ -113,12 +234,18 @@ class PurgeMediaCacheRestServlet(ClientV1RestServlet):
class PurgeHistoryRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns(
- "/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
+ "/admin/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?"
)
def __init__(self, hs):
+ """
+
+ Args:
+ hs (synapse.server.HomeServer)
+ """
super(PurgeHistoryRestServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self.pagination_handler = hs.get_pagination_handler()
+ self.store = hs.get_datastore()
@defer.inlineCallbacks
def on_POST(self, request, room_id, event_id):
@@ -128,20 +255,127 @@ class PurgeHistoryRestServlet(ClientV1RestServlet):
if not is_admin:
raise AuthError(403, "You are not a server admin")
- yield self.handlers.message_handler.purge_history(room_id, event_id)
+ body = parse_json_object_from_request(request, allow_empty_body=True)
- defer.returnValue((200, {}))
+ delete_local_events = bool(body.get("delete_local_events", False))
+
+ # establish the topological ordering we should keep events from. The
+ # user can provide an event_id in the URL or the request body, or can
+ # provide a timestamp in the request body.
+ if event_id is None:
+ event_id = body.get('purge_up_to_event_id')
+
+ if event_id is not None:
+ event = yield self.store.get_event(event_id)
+
+ if event.room_id != room_id:
+ raise SynapseError(400, "Event is for wrong room.")
+
+ token = yield self.store.get_topological_token_for_event(event_id)
+
+ logger.info(
+ "[purge] purging up to token %s (event_id %s)",
+ token, event_id,
+ )
+ elif 'purge_up_to_ts' in body:
+ ts = body['purge_up_to_ts']
+ if not isinstance(ts, int):
+ raise SynapseError(
+ 400, "purge_up_to_ts must be an int",
+ errcode=Codes.BAD_JSON,
+ )
+
+ stream_ordering = (
+ yield self.store.find_first_stream_ordering_after_ts(ts)
+ )
+
+ r = (
+ yield self.store.get_room_event_after_stream_ordering(
+ room_id, stream_ordering,
+ )
+ )
+ if not r:
+ logger.warn(
+ "[purge] purging events not possible: No event found "
+ "(received_ts %i => stream_ordering %i)",
+ ts, stream_ordering,
+ )
+ raise SynapseError(
+ 404,
+ "there is no event to be purged",
+ errcode=Codes.NOT_FOUND,
+ )
+ (stream, topo, _event_id) = r
+ token = "t%d-%d" % (topo, stream)
+ logger.info(
+ "[purge] purging up to token %s (received_ts %i => "
+ "stream_ordering %i)",
+ token, ts, stream_ordering,
+ )
+ else:
+ raise SynapseError(
+ 400,
+ "must specify purge_up_to_event_id or purge_up_to_ts",
+ errcode=Codes.BAD_JSON,
+ )
+
+ purge_id = yield self.pagination_handler.start_purge_history(
+ room_id, token,
+ delete_local_events=delete_local_events,
+ )
+
+ defer.returnValue((200, {
+ "purge_id": purge_id,
+ }))
+
+
+class PurgeHistoryStatusRestServlet(ClientV1RestServlet):
+ PATTERNS = client_path_patterns(
+ "/admin/purge_history_status/(?P<purge_id>[^/]+)"
+ )
+
+ def __init__(self, hs):
+ """
+
+ Args:
+ hs (synapse.server.HomeServer)
+ """
+ super(PurgeHistoryStatusRestServlet, self).__init__(hs)
+ self.pagination_handler = hs.get_pagination_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, purge_id):
+ requester = yield self.auth.get_user_by_req(request)
+ is_admin = yield self.auth.is_server_admin(requester.user)
+
+ if not is_admin:
+ raise AuthError(403, "You are not a server admin")
+
+ purge_status = self.pagination_handler.get_purge_status(purge_id)
+ if purge_status is None:
+ raise NotFoundError("purge id '%s' not found" % purge_id)
+
+ defer.returnValue((200, purge_status.asdict()))
class DeactivateAccountRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)")
def __init__(self, hs):
- self.store = hs.get_datastore()
super(DeactivateAccountRestServlet, self).__init__(hs)
+ self._deactivate_account_handler = hs.get_deactivate_account_handler()
@defer.inlineCallbacks
def on_POST(self, request, target_user_id):
+ body = parse_json_object_from_request(request, allow_empty_body=True)
+ erase = body.get("erase", False)
+ if not isinstance(erase, bool):
+ raise SynapseError(
+ http_client.BAD_REQUEST,
+ "Param 'erase' must be a boolean, if given",
+ Codes.BAD_JSON,
+ )
+
UserID.from_string(target_user_id)
requester = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(requester.user)
@@ -149,12 +383,9 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
if not is_admin:
raise AuthError(403, "You are not a server admin")
- # FIXME: Theoretically there is a race here wherein user resets password
- # using threepid.
- yield self.store.user_delete_access_tokens(target_user_id)
- yield self.store.user_delete_threepids(target_user_id)
- yield self.store.user_set_password_hash(target_user_id, None)
-
+ yield self._deactivate_account_handler.deactivate_account(
+ target_user_id, erase,
+ )
defer.returnValue((200, {}))
@@ -168,14 +399,16 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
DEFAULT_MESSAGE = (
"Sharing illegal content on this server is not permitted and rooms in"
- " violatation will be blocked."
+ " violation will be blocked."
)
def __init__(self, hs):
super(ShutdownRoomRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
- self.handlers = hs.get_handlers()
self.state = hs.get_state_handler()
+ self._room_creation_handler = hs.get_room_creation_handler()
+ self.event_creation_handler = hs.get_event_creation_handler()
+ self.room_member_handler = hs.get_room_member_handler()
@defer.inlineCallbacks
def on_POST(self, request, room_id):
@@ -185,17 +418,15 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
raise AuthError(403, "You are not a server admin")
content = parse_json_object_from_request(request)
-
- new_room_user_id = content.get("new_room_user_id")
- if not new_room_user_id:
- raise SynapseError(400, "Please provide field `new_room_user_id`")
+ assert_params_in_dict(content, ["new_room_user_id"])
+ new_room_user_id = content["new_room_user_id"]
room_creator_requester = create_requester(new_room_user_id)
message = content.get("message", self.DEFAULT_MESSAGE)
room_name = content.get("room_name", "Content Violation Notification")
- info = yield self.handlers.room_creation_handler.create_room(
+ info = yield self._room_creation_handler.create_room(
room_creator_requester,
config={
"preset": "public_chat",
@@ -208,8 +439,7 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
)
new_room_id = info["room_id"]
- msg_handler = self.handlers.message_handler
- yield msg_handler.create_and_send_nonmember_event(
+ yield self.event_creation_handler.create_and_send_nonmember_event(
room_creator_requester,
{
"type": "m.room.message",
@@ -235,7 +465,7 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
logger.info("Kicking %r from %r...", user_id, room_id)
target_requester = create_requester(user_id)
- yield self.handlers.room_member_handler.update_membership(
+ yield self.room_member_handler.update_membership(
requester=target_requester,
target=target_requester.user,
room_id=room_id,
@@ -244,9 +474,9 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
ratelimit=False
)
- yield self.handlers.room_member_handler.forget(target_requester.user, room_id)
+ yield self.room_member_handler.forget(target_requester.user, room_id)
- yield self.handlers.room_member_handler.update_membership(
+ yield self.room_member_handler.update_membership(
requester=target_requester,
target=target_requester.user,
room_id=new_room_id,
@@ -294,9 +524,30 @@ class QuarantineMediaInRoom(ClientV1RestServlet):
defer.returnValue((200, {"num_quarantined": num_quarantined}))
+class ListMediaInRoom(ClientV1RestServlet):
+ """Lists all of the media in a given room.
+ """
+ PATTERNS = client_path_patterns("/admin/room/(?P<room_id>[^/]+)/media")
+
+ def __init__(self, hs):
+ super(ListMediaInRoom, self).__init__(hs)
+ self.store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, room_id):
+ requester = yield self.auth.get_user_by_req(request)
+ is_admin = yield self.auth.is_server_admin(requester.user)
+ if not is_admin:
+ raise AuthError(403, "You are not a server admin")
+
+ local_mxcs, remote_mxcs = yield self.store.get_media_mxcs_in_room(room_id)
+
+ defer.returnValue((200, {"local": local_mxcs, "remote": remote_mxcs}))
+
+
class ResetPasswordRestServlet(ClientV1RestServlet):
"""Post request to allow an administrator reset password for a user.
- This need a user have a administrator access in Synapse.
+ This needs user to have administrator access in Synapse.
Example:
http://localhost:8008/_matrix/client/api/v1/admin/reset_password/
@user:to_reset_password?access_token=admin_access_token
@@ -314,12 +565,12 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
super(ResetPasswordRestServlet, self).__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
- self.auth_handler = hs.get_auth_handler()
+ self._set_password_handler = hs.get_set_password_handler()
@defer.inlineCallbacks
def on_POST(self, request, target_user_id):
"""Post request to allow an administrator reset password for a user.
- This need a user have a administrator access in Synapse.
+ This needs user to have administrator access in Synapse.
"""
UserID.from_string(target_user_id)
requester = yield self.auth.get_user_by_req(request)
@@ -329,13 +580,12 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
raise AuthError(403, "You are not a server admin")
params = parse_json_object_from_request(request)
+ assert_params_in_dict(params, ["new_password"])
new_password = params['new_password']
- if not new_password:
- raise SynapseError(400, "Missing 'new_password' arg")
logger.info("new_password: %r", new_password)
- yield self.auth_handler.set_password(
+ yield self._set_password_handler.set_password(
target_user_id, new_password, requester
)
defer.returnValue((200, {}))
@@ -343,7 +593,7 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
class GetUsersPaginatedRestServlet(ClientV1RestServlet):
"""Get request to get specific number of users from Synapse.
- This need a user have a administrator access in Synapse.
+ This needs user to have administrator access in Synapse.
Example:
http://localhost:8008/_matrix/client/api/v1/admin/users_paginate/
@admin:user?access_token=admin_access_token&start=0&limit=10
@@ -362,7 +612,7 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, target_user_id):
"""Get request to get specific number of users from Synapse.
- This need a user have a administrator access in Synapse.
+ This needs user to have administrator access in Synapse.
"""
target_user = UserID.from_string(target_user_id)
requester = yield self.auth.get_user_by_req(request)
@@ -379,12 +629,9 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet):
raise SynapseError(400, "Can only users a local user")
order = "name" # order by name in user table
- start = request.args.get("start")[0]
- limit = request.args.get("limit")[0]
- if not limit:
- raise SynapseError(400, "Missing 'limit' arg")
- if not start:
- raise SynapseError(400, "Missing 'start' arg")
+ start = parse_integer(request, "start", required=True)
+ limit = parse_integer(request, "limit", required=True)
+
logger.info("limit: %s, start: %s", limit, start)
ret = yield self.handlers.admin_handler.get_users_paginate(
@@ -395,7 +642,7 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request, target_user_id):
"""Post request to get specific number of users from Synapse..
- This need a user have a administrator access in Synapse.
+ This needs user to have administrator access in Synapse.
Example:
http://localhost:8008/_matrix/client/api/v1/admin/users_paginate/
@admin:user?access_token=admin_access_token
@@ -416,12 +663,9 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet):
order = "name" # order by name in user table
params = parse_json_object_from_request(request)
+ assert_params_in_dict(params, ["limit", "start"])
limit = params['limit']
start = params['start']
- if not limit:
- raise SynapseError(400, "Missing 'limit' arg")
- if not start:
- raise SynapseError(400, "Missing 'start' arg")
logger.info("limit: %s, start: %s", limit, start)
ret = yield self.handlers.admin_handler.get_users_paginate(
@@ -433,7 +677,7 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet):
class SearchUsersRestServlet(ClientV1RestServlet):
"""Get request to search user table for specific users according to
search term.
- This need a user have a administrator access in Synapse.
+ This needs user to have administrator access in Synapse.
Example:
http://localhost:8008/_matrix/client/api/v1/admin/search_users/
@admin:user?access_token=admin_access_token&term=alice
@@ -453,7 +697,7 @@ class SearchUsersRestServlet(ClientV1RestServlet):
def on_GET(self, request, target_user_id):
"""Get request to search user table for specific users according to
search term.
- This need a user have a administrator access in Synapse.
+ This needs user to have a administrator access in Synapse.
"""
target_user = UserID.from_string(target_user_id)
requester = yield self.auth.get_user_by_req(request)
@@ -469,10 +713,7 @@ class SearchUsersRestServlet(ClientV1RestServlet):
if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only users a local user")
- term = request.args.get("term")[0]
- if not term:
- raise SynapseError(400, "Missing 'term' arg")
-
+ term = parse_string(request, "term", required=True)
logger.info("term: %s ", term)
ret = yield self.handlers.admin_handler.search_users(
@@ -484,6 +725,7 @@ class SearchUsersRestServlet(ClientV1RestServlet):
def register_servlets(hs, http_server):
WhoisRestServlet(hs).register(http_server)
PurgeMediaCacheRestServlet(hs).register(http_server)
+ PurgeHistoryStatusRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server)
PurgeHistoryRestServlet(hs).register(http_server)
UsersRestServlet(hs).register(http_server)
@@ -492,3 +734,5 @@ def register_servlets(hs, http_server):
SearchUsersRestServlet(hs).register(http_server)
ShutdownRoomRestServlet(hs).register(http_server)
QuarantineMediaInRoom(hs).register(http_server)
+ ListMediaInRoom(hs).register(http_server)
+ UserRegisterServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py
index c7aa0bbf59..c77d7aba68 100644
--- a/synapse/rest/client/v1/base.py
+++ b/synapse/rest/client/v1/base.py
@@ -16,14 +16,12 @@
"""This module contains base REST classes for constructing client v1 servlets.
"""
-from synapse.http.servlet import RestServlet
-from synapse.api.urls import CLIENT_PREFIX
-from synapse.rest.client.transactions import HttpTransactionCache
-
-import re
-
import logging
+import re
+from synapse.api.urls import CLIENT_PREFIX
+from synapse.http.servlet import RestServlet
+from synapse.rest.client.transactions import HttpTransactionCache
logger = logging.getLogger(__name__)
@@ -52,6 +50,10 @@ class ClientV1RestServlet(RestServlet):
"""A base Synapse REST Servlet for the client version 1 API.
"""
+ # This subclass was presumably created to allow the auth for the v1
+ # protocol version to be different, however this behaviour was removed.
+ # it may no longer be necessary
+
def __init__(self, hs):
"""
Args:
@@ -59,5 +61,5 @@ class ClientV1RestServlet(RestServlet):
"""
self.hs = hs
self.builder_factory = hs.get_event_builder_factory()
- self.auth = hs.get_v1auth()
- self.txns = HttpTransactionCache(hs.get_clock())
+ self.auth = hs.get_auth()
+ self.txns = HttpTransactionCache(hs)
diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
index f15aa5c13f..69dcd618cb 100644
--- a/synapse/rest/client/v1/directory.py
+++ b/synapse/rest/client/v1/directory.py
@@ -14,17 +14,16 @@
# limitations under the License.
+import logging
+
from twisted.internet import defer
-from synapse.api.errors import AuthError, SynapseError, Codes
-from synapse.types import RoomAlias
+from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.http.servlet import parse_json_object_from_request
+from synapse.types import RoomAlias
from .base import ClientV1RestServlet, client_path_patterns
-import logging
-
-
logger = logging.getLogger(__name__)
@@ -53,15 +52,14 @@ class ClientDirectoryServer(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, room_alias):
+ room_alias = RoomAlias.from_string(room_alias)
+
content = parse_json_object_from_request(request)
if "room_id" not in content:
- raise SynapseError(400, "Missing room_id key",
+ raise SynapseError(400, 'Missing params: ["room_id"]',
errcode=Codes.BAD_JSON)
logger.debug("Got content: %s", content)
-
- room_alias = RoomAlias.from_string(room_alias)
-
logger.debug("Got room name: %s", room_alias.to_string())
room_id = content["room_id"]
@@ -93,7 +91,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
)
except SynapseError as e:
raise e
- except:
+ except Exception:
logger.exception("Failed to create association")
raise
except AuthError:
diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py
index 701b6f549b..b70c9c2806 100644
--- a/synapse/rest/client/v1/events.py
+++ b/synapse/rest/client/v1/events.py
@@ -14,15 +14,15 @@
# limitations under the License.
"""This module contains REST servlets to do with event streaming, /events."""
+import logging
+
from twisted.internet import defer
from synapse.api.errors import SynapseError
-from synapse.streams.config import PaginationConfig
-from .base import ClientV1RestServlet, client_path_patterns
from synapse.events.utils import serialize_event
+from synapse.streams.config import PaginationConfig
-import logging
-
+from .base import ClientV1RestServlet, client_path_patterns
logger = logging.getLogger(__name__)
diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py
index 478e21eea8..fd5f85b53e 100644
--- a/synapse/rest/client/v1/initial_sync.py
+++ b/synapse/rest/client/v1/initial_sync.py
@@ -15,7 +15,9 @@
from twisted.internet import defer
+from synapse.http.servlet import parse_boolean
from synapse.streams.config import PaginationConfig
+
from .base import ClientV1RestServlet, client_path_patterns
@@ -32,7 +34,7 @@ class InitialSyncRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request)
as_client_event = "raw" not in request.args
pagination_config = PaginationConfig.from_request(request)
- include_archived = request.args.get("archived", None) == ["true"]
+ include_archived = parse_boolean(request, "archived", default=False)
content = yield self.initial_sync_handler.snapshot_all_rooms(
user_id=requester.user.to_string(),
pagin_config=pagination_config,
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index a43410fb37..cb85fa1436 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -13,30 +13,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+import urllib
+import xml.etree.ElementTree as ET
+
+from six.moves.urllib import parse as urlparse
+
+from canonicaljson import json
+from saml2 import BINDING_HTTP_POST, config
+from saml2.client import Saml2Client
+
from twisted.internet import defer
+from twisted.web.client import PartialDownloadError
-from synapse.api.errors import SynapseError, LoginError, Codes
-from synapse.types import UserID
+from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.http.server import finish_request
from synapse.http.servlet import parse_json_object_from_request
+from synapse.types import UserID
from synapse.util.msisdn import phone_number_to_msisdn
from .base import ClientV1RestServlet, client_path_patterns
-import simplejson as json
-import urllib
-import urlparse
-
-import logging
-from saml2 import BINDING_HTTP_POST
-from saml2 import config
-from saml2.client import Saml2Client
-
-import xml.etree.ElementTree as ET
-
-from twisted.web.client import PartialDownloadError
-
-
logger = logging.getLogger(__name__)
@@ -85,7 +82,6 @@ def login_id_thirdparty_from_phone(identifier):
class LoginRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/login$")
- PASS_TYPE = "m.login.password"
SAML2_TYPE = "m.login.saml2"
CAS_TYPE = "m.login.cas"
TOKEN_TYPE = "m.login.token"
@@ -94,7 +90,6 @@ class LoginRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(LoginRestServlet, self).__init__(hs)
self.idp_redirect_url = hs.config.saml2_idp_redirect_url
- self.password_enabled = hs.config.password_enabled
self.saml2_enabled = hs.config.saml2_enabled
self.jwt_enabled = hs.config.jwt_enabled
self.jwt_secret = hs.config.jwt_secret
@@ -121,8 +116,10 @@ class LoginRestServlet(ClientV1RestServlet):
# fall back to the fallback API if they don't understand one of the
# login flow types returned.
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
- if self.password_enabled:
- flows.append({"type": LoginRestServlet.PASS_TYPE})
+
+ flows.extend((
+ {"type": t} for t in self.auth_handler.get_supported_login_types()
+ ))
return (200, {"flows": flows})
@@ -133,14 +130,8 @@ class LoginRestServlet(ClientV1RestServlet):
def on_POST(self, request):
login_submission = parse_json_object_from_request(request)
try:
- if login_submission["type"] == LoginRestServlet.PASS_TYPE:
- if not self.password_enabled:
- raise SynapseError(400, "Password login has been disabled.")
-
- result = yield self.do_password_login(login_submission)
- defer.returnValue(result)
- elif self.saml2_enabled and (login_submission["type"] ==
- LoginRestServlet.SAML2_TYPE):
+ if self.saml2_enabled and (login_submission["type"] ==
+ LoginRestServlet.SAML2_TYPE):
relay_state = ""
if "relay_state" in login_submission:
relay_state = "&RelayState=" + urllib.quote(
@@ -157,15 +148,31 @@ class LoginRestServlet(ClientV1RestServlet):
result = yield self.do_token_login(login_submission)
defer.returnValue(result)
else:
- raise SynapseError(400, "Bad login type.")
+ result = yield self._do_other_login(login_submission)
+ defer.returnValue(result)
except KeyError:
raise SynapseError(400, "Missing JSON keys.")
@defer.inlineCallbacks
- def do_password_login(self, login_submission):
- if "password" not in login_submission:
- raise SynapseError(400, "Missing parameter: password")
+ def _do_other_login(self, login_submission):
+ """Handle non-token/saml/jwt logins
+
+ Args:
+ login_submission:
+ Returns:
+ (int, object): HTTP code/response
+ """
+ # Log the request we got, but only certain fields to minimise the chance of
+ # logging someone's password (even if they accidentally put it in the wrong
+ # field)
+ logger.info(
+ "Got login request with identifier: %r, medium: %r, address: %r, user: %r",
+ login_submission.get('identifier'),
+ login_submission.get('medium'),
+ login_submission.get('address'),
+ login_submission.get('user'),
+ )
login_submission_legacy_convert(login_submission)
if "identifier" not in login_submission:
@@ -181,19 +188,25 @@ class LoginRestServlet(ClientV1RestServlet):
# convert threepid identifiers to user IDs
if identifier["type"] == "m.id.thirdparty":
- if 'medium' not in identifier or 'address' not in identifier:
+ address = identifier.get('address')
+ medium = identifier.get('medium')
+
+ if medium is None or address is None:
raise SynapseError(400, "Invalid thirdparty identifier")
- address = identifier['address']
- if identifier['medium'] == 'email':
+ if medium == 'email':
# For emails, transform the address to lowercase.
# We store all email addreses as lowercase in the DB.
# (See add_threepid in synapse/handlers/auth.py)
address = address.lower()
user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
- identifier['medium'], address
+ medium, address,
)
if not user_id:
+ logger.warn(
+ "unknown 3pid identifier medium %s, address %r",
+ medium, address,
+ )
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
identifier = {
@@ -208,30 +221,29 @@ class LoginRestServlet(ClientV1RestServlet):
if "user" not in identifier:
raise SynapseError(400, "User identifier is missing 'user' key")
- user_id = identifier["user"]
-
- if not user_id.startswith('@'):
- user_id = UserID.create(
- user_id, self.hs.hostname
- ).to_string()
-
auth_handler = self.auth_handler
- user_id = yield auth_handler.validate_password_login(
- user_id=user_id,
- password=login_submission["password"],
+ canonical_user_id, callback = yield auth_handler.validate_login(
+ identifier["user"],
+ login_submission,
+ )
+
+ device_id = yield self._register_device(
+ canonical_user_id, login_submission,
)
- device_id = yield self._register_device(user_id, login_submission)
access_token = yield auth_handler.get_access_token_for_user_id(
- user_id, device_id,
- login_submission.get("initial_device_display_name"),
+ canonical_user_id, device_id,
)
+
result = {
- "user_id": user_id, # may have changed
+ "user_id": canonical_user_id,
"access_token": access_token,
"home_server": self.hs.hostname,
"device_id": device_id,
}
+ if callback is not None:
+ yield callback(result)
+
defer.returnValue((200, result))
@defer.inlineCallbacks
@@ -244,7 +256,6 @@ class LoginRestServlet(ClientV1RestServlet):
device_id = yield self._register_device(user_id, login_submission)
access_token = yield auth_handler.get_access_token_for_user_id(
user_id, device_id,
- login_submission.get("initial_device_display_name"),
)
result = {
"user_id": user_id, # may have changed
@@ -278,7 +289,7 @@ class LoginRestServlet(ClientV1RestServlet):
if user is None:
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
- user_id = UserID.create(user, self.hs.hostname).to_string()
+ user_id = UserID(user, self.hs.hostname).to_string()
auth_handler = self.auth_handler
registered_user_id = yield auth_handler.check_user_exists(user_id)
if registered_user_id:
@@ -287,7 +298,6 @@ class LoginRestServlet(ClientV1RestServlet):
)
access_token = yield auth_handler.get_access_token_for_user_id(
registered_user_id, device_id,
- login_submission.get("initial_device_display_name"),
)
result = {
@@ -444,7 +454,7 @@ class CasTicketServlet(ClientV1RestServlet):
if required_value != actual_value:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
- user_id = UserID.create(user, self.hs.hostname).to_string()
+ user_id = UserID(user, self.hs.hostname).to_string()
auth_handler = self.auth_handler
registered_user_id = yield auth_handler.check_user_exists(user_id)
if not registered_user_id:
diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py
index 1358d0acab..430c692336 100644
--- a/synapse/rest/client/v1/logout.py
+++ b/synapse/rest/client/v1/logout.py
@@ -13,15 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
from twisted.internet import defer
-from synapse.api.auth import get_access_token_from_request
+from synapse.api.errors import AuthError
from .base import ClientV1RestServlet, client_path_patterns
-import logging
-
-
logger = logging.getLogger(__name__)
@@ -30,15 +29,33 @@ class LogoutRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(LogoutRestServlet, self).__init__(hs)
- self.store = hs.get_datastore()
+ self._auth = hs.get_auth()
+ self._auth_handler = hs.get_auth_handler()
+ self._device_handler = hs.get_device_handler()
def on_OPTIONS(self, request):
return (200, {})
@defer.inlineCallbacks
def on_POST(self, request):
- access_token = get_access_token_from_request(request)
- yield self.store.delete_access_token(access_token)
+ try:
+ requester = yield self.auth.get_user_by_req(request)
+ except AuthError:
+ # this implies the access token has already been deleted.
+ defer.returnValue((401, {
+ "errcode": "M_UNKNOWN_TOKEN",
+ "error": "Access Token unknown or expired"
+ }))
+ else:
+ if requester.device_id is None:
+ # the acccess token wasn't associated with a device.
+ # Just delete the access token
+ access_token = self._auth.get_access_token_from_request(request)
+ yield self._auth_handler.delete_access_token(access_token)
+ else:
+ yield self._device_handler.delete_device(
+ requester.user.to_string(), requester.device_id)
+
defer.returnValue((200, {}))
@@ -47,8 +64,9 @@ class LogoutAllRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(LogoutAllRestServlet, self).__init__(hs)
- self.store = hs.get_datastore()
self.auth = hs.get_auth()
+ self._auth_handler = hs.get_auth_handler()
+ self._device_handler = hs.get_device_handler()
def on_OPTIONS(self, request):
return (200, {})
@@ -57,7 +75,13 @@ class LogoutAllRestServlet(ClientV1RestServlet):
def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
- yield self.store.user_delete_access_tokens(user_id)
+
+ # first delete all of the user's devices
+ yield self._device_handler.delete_all_devices_for_user(user_id)
+
+ # .. and then delete any access tokens which weren't associated with
+ # devices.
+ yield self._auth_handler.delete_access_tokens_for_user(user_id)
defer.returnValue((200, {}))
diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index 47b2dc45e7..a14f0c807e 100644
--- a/synapse/rest/client/v1/presence.py
+++ b/synapse/rest/client/v1/presence.py
@@ -15,15 +15,18 @@
""" This module contains REST servlets to do with presence: /presence/<paths>
"""
+import logging
+
+from six import string_types
+
from twisted.internet import defer
-from synapse.api.errors import SynapseError, AuthError
-from synapse.types import UserID
+from synapse.api.errors import AuthError, SynapseError
from synapse.handlers.presence import format_user_presence_state
from synapse.http.servlet import parse_json_object_from_request
-from .base import ClientV1RestServlet, client_path_patterns
+from synapse.types import UserID
-import logging
+from .base import ClientV1RestServlet, client_path_patterns
logger = logging.getLogger(__name__)
@@ -71,14 +74,14 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
if "status_msg" in content:
state["status_msg"] = content.pop("status_msg")
- if not isinstance(state["status_msg"], basestring):
+ if not isinstance(state["status_msg"], string_types):
raise SynapseError(400, "status_msg must be a string.")
if content:
raise KeyError()
except SynapseError as e:
raise e
- except:
+ except Exception:
raise SynapseError(400, "Unable to parse state")
yield self.presence_handler.set_state(user, state)
@@ -129,7 +132,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
if "invite" in content:
for u in content["invite"]:
- if not isinstance(u, basestring):
+ if not isinstance(u, string_types):
raise SynapseError(400, "Bad invite value.")
if len(u) == 0:
continue
@@ -140,7 +143,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
if "drop" in content:
for u in content["drop"]:
- if not isinstance(u, basestring):
+ if not isinstance(u, string_types):
raise SynapseError(400, "Bad drop value.")
if len(u) == 0:
continue
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index 1a5045c9ec..a23edd8fe5 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -16,9 +16,10 @@
""" This module contains REST servlets to do with profile: /profile/<paths> """
from twisted.internet import defer
-from .base import ClientV1RestServlet, client_path_patterns
-from synapse.types import UserID
from synapse.http.servlet import parse_json_object_from_request
+from synapse.types import UserID
+
+from .base import ClientV1RestServlet, client_path_patterns
class ProfileDisplaynameRestServlet(ClientV1RestServlet):
@@ -26,13 +27,13 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(ProfileDisplaynameRestServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self.profile_handler = hs.get_profile_handler()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
user = UserID.from_string(user_id)
- displayname = yield self.handlers.profile_handler.get_displayname(
+ displayname = yield self.profile_handler.get_displayname(
user,
)
@@ -52,10 +53,10 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
try:
new_name = content["displayname"]
- except:
+ except Exception:
defer.returnValue((400, "Unable to parse name"))
- yield self.handlers.profile_handler.set_displayname(
+ yield self.profile_handler.set_displayname(
user, requester, new_name, is_admin)
defer.returnValue((200, {}))
@@ -69,13 +70,13 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(ProfileAvatarURLRestServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self.profile_handler = hs.get_profile_handler()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
user = UserID.from_string(user_id)
- avatar_url = yield self.handlers.profile_handler.get_avatar_url(
+ avatar_url = yield self.profile_handler.get_avatar_url(
user,
)
@@ -94,10 +95,10 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
content = parse_json_object_from_request(request)
try:
new_name = content["avatar_url"]
- except:
+ except Exception:
defer.returnValue((400, "Unable to parse name"))
- yield self.handlers.profile_handler.set_avatar_url(
+ yield self.profile_handler.set_avatar_url(
user, requester, new_name, is_admin)
defer.returnValue((200, {}))
@@ -111,16 +112,16 @@ class ProfileRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(ProfileRestServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self.profile_handler = hs.get_profile_handler()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
user = UserID.from_string(user_id)
- displayname = yield self.handlers.profile_handler.get_displayname(
+ displayname = yield self.profile_handler.get_displayname(
user,
)
- avatar_url = yield self.handlers.profile_handler.get_avatar_url(
+ avatar_url = yield self.profile_handler.get_avatar_url(
user,
)
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index 6bb4821ec6..6e95d9bec2 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -16,16 +16,18 @@
from twisted.internet import defer
from synapse.api.errors import (
- SynapseError, UnrecognizedRequestError, NotFoundError, StoreError
+ NotFoundError,
+ StoreError,
+ SynapseError,
+ UnrecognizedRequestError,
)
-from .base import ClientV1RestServlet, client_path_patterns
-from synapse.storage.push_rule import (
- InconsistentRuleException, RuleNotFoundException
-)
-from synapse.push.clientformat import format_push_rules_for_user
+from synapse.http.servlet import parse_json_value_from_request, parse_string
from synapse.push.baserules import BASE_RULE_IDS
+from synapse.push.clientformat import format_push_rules_for_user
from synapse.push.rulekinds import PRIORITY_CLASS_MAP
-from synapse.http.servlet import parse_json_value_from_request
+from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
+
+from .base import ClientV1RestServlet, client_path_patterns
class PushRuleRestServlet(ClientV1RestServlet):
@@ -73,13 +75,13 @@ class PushRuleRestServlet(ClientV1RestServlet):
except InvalidRuleException as e:
raise SynapseError(400, e.message)
- before = request.args.get("before", None)
+ before = parse_string(request, "before")
if before:
- before = _namespaced_rule_id(spec, before[0])
+ before = _namespaced_rule_id(spec, before)
- after = request.args.get("after", None)
+ after = parse_string(request, "after")
if after:
- after = _namespaced_rule_id(spec, after[0])
+ after = _namespaced_rule_id(spec, after)
try:
yield self.store.add_push_rule(
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 9a2ed6ed88..182a68b1e2 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -13,20 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
from twisted.internet import defer
-from synapse.api.errors import SynapseError, Codes
-from synapse.push import PusherConfigException
+from synapse.api.errors import Codes, StoreError, SynapseError
+from synapse.http.server import finish_request
from synapse.http.servlet import (
- parse_json_object_from_request, parse_string, RestServlet
+ RestServlet,
+ assert_params_in_dict,
+ parse_json_object_from_request,
+ parse_string,
)
-from synapse.http.server import finish_request
-from synapse.api.errors import StoreError
+from synapse.push import PusherConfigException
from .base import ClientV1RestServlet, client_path_patterns
-import logging
-
logger = logging.getLogger(__name__)
@@ -73,6 +75,7 @@ class PushersSetRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(PushersSetRestServlet, self).__init__(hs)
self.notifier = hs.get_notifier()
+ self.pusher_pool = self.hs.get_pusherpool()
@defer.inlineCallbacks
def on_POST(self, request):
@@ -81,25 +84,19 @@ class PushersSetRestServlet(ClientV1RestServlet):
content = parse_json_object_from_request(request)
- pusher_pool = self.hs.get_pusherpool()
-
if ('pushkey' in content and 'app_id' in content
and 'kind' in content and
content['kind'] is None):
- yield pusher_pool.remove_pusher(
+ yield self.pusher_pool.remove_pusher(
content['app_id'], content['pushkey'], user_id=user.to_string()
)
defer.returnValue((200, {}))
- reqd = ['kind', 'app_id', 'app_display_name',
- 'device_display_name', 'pushkey', 'lang', 'data']
- missing = []
- for i in reqd:
- if i not in content:
- missing.append(i)
- if len(missing):
- raise SynapseError(400, "Missing parameters: " + ','.join(missing),
- errcode=Codes.MISSING_PARAM)
+ assert_params_in_dict(
+ content,
+ ['kind', 'app_id', 'app_display_name',
+ 'device_display_name', 'pushkey', 'lang', 'data']
+ )
logger.debug("set pushkey %s to kind %s", content['pushkey'], content['kind'])
logger.debug("Got pushers request with body: %r", content)
@@ -109,14 +106,14 @@ class PushersSetRestServlet(ClientV1RestServlet):
append = content['append']
if not append:
- yield pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
+ yield self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
app_id=content['app_id'],
pushkey=content['pushkey'],
not_user_id=user.to_string()
)
try:
- yield pusher_pool.add_pusher(
+ yield self.pusher_pool.add_pusher(
user_id=user.to_string(),
access_token=requester.access_token_id,
kind=content['kind'],
@@ -148,10 +145,11 @@ class PushersRemoveRestServlet(RestServlet):
SUCCESS_HTML = "<html><body>You have been unsubscribed</body><html>"
def __init__(self, hs):
- super(RestServlet, self).__init__()
+ super(PushersRemoveRestServlet, self).__init__()
self.hs = hs
self.notifier = hs.get_notifier()
- self.auth = hs.get_v1auth()
+ self.auth = hs.get_auth()
+ self.pusher_pool = self.hs.get_pusherpool()
@defer.inlineCallbacks
def on_GET(self, request):
@@ -161,10 +159,8 @@ class PushersRemoveRestServlet(RestServlet):
app_id = parse_string(request, "app_id", required=True)
pushkey = parse_string(request, "pushkey", required=True)
- pusher_pool = self.hs.get_pusherpool()
-
try:
- yield pusher_pool.remove_pusher(
+ yield self.pusher_pool.remove_pusher(
app_id=app_id,
pushkey=pushkey,
user_id=user.to_string(),
@@ -178,7 +174,6 @@ class PushersRemoveRestServlet(RestServlet):
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(b"Server", self.hs.version_string)
request.setHeader(b"Content-Length", b"%d" % (
len(PushersRemoveRestServlet.SUCCESS_HTML),
))
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index cd388770c8..b7bd878c90 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,22 +15,28 @@
# limitations under the License.
""" This module contains REST servlets to do with rooms: /rooms/<paths> """
+import logging
+
+from six.moves.urllib import parse as urlparse
+
+from canonicaljson import json
+
from twisted.internet import defer
-from .base import ClientV1RestServlet, client_path_patterns
-from synapse.api.errors import SynapseError, Codes, AuthError
-from synapse.streams.config import PaginationConfig
from synapse.api.constants import EventTypes, Membership
+from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.api.filtering import Filter
-from synapse.types import UserID, RoomID, RoomAlias, ThirdPartyInstanceID
-from synapse.events.utils import serialize_event, format_event_for_client_v2
+from synapse.events.utils import format_event_for_client_v2, serialize_event
from synapse.http.servlet import (
- parse_json_object_from_request, parse_string, parse_integer
+ assert_params_in_dict,
+ parse_integer,
+ parse_json_object_from_request,
+ parse_string,
)
+from synapse.streams.config import PaginationConfig
+from synapse.types import RoomAlias, RoomID, ThirdPartyInstanceID, UserID
-import logging
-import urllib
-import ujson as json
+from .base import ClientV1RestServlet, client_path_patterns
logger = logging.getLogger(__name__)
@@ -39,7 +46,7 @@ class RoomCreateRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomCreateRestServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self._room_creation_handler = hs.get_room_creation_handler()
def register(self, http_server):
PATTERNS = "/createRoom"
@@ -62,8 +69,7 @@ class RoomCreateRestServlet(ClientV1RestServlet):
def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request)
- handler = self.handlers.room_creation_handler
- info = yield handler.create_room(
+ info = yield self._room_creation_handler.create_room(
requester, self.get_room_config(request)
)
@@ -82,6 +88,9 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomStateEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
+ self.event_creation_hander = hs.get_event_creation_handler()
+ self.room_member_handler = hs.get_room_member_handler()
+ self.message_handler = hs.get_message_handler()
def register(self, http_server):
# /room/$roomid/state/$eventtype
@@ -116,7 +125,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
format = parse_string(request, "format", default="content",
allowed_values=["content", "event"])
- msg_handler = self.handlers.message_handler
+ msg_handler = self.message_handler
data = yield msg_handler.get_room_data(
user_id=requester.user.to_string(),
room_id=room_id,
@@ -154,7 +163,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
if event_type == EventTypes.Member:
membership = content.get("membership", None)
- event = yield self.handlers.room_member_handler.update_membership(
+ event = yield self.room_member_handler.update_membership(
requester,
target=UserID.from_string(state_key),
room_id=room_id,
@@ -162,16 +171,12 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
content=content,
)
else:
- msg_handler = self.handlers.message_handler
- event, context = yield msg_handler.create_event(
+ event = yield self.event_creation_hander.create_and_send_nonmember_event(
requester,
event_dict,
- token_id=requester.access_token_id,
txn_id=txn_id,
)
- yield msg_handler.send_nonmember_event(requester, event, context)
-
ret = {}
if event:
ret = {"event_id": event.event_id}
@@ -183,7 +188,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomSendEventRestServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self.event_creation_hander = hs.get_event_creation_handler()
def register(self, http_server):
# /rooms/$roomid/send/$event_type[/$txn_id]
@@ -195,15 +200,19 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
- msg_handler = self.handlers.message_handler
- event = yield msg_handler.create_and_send_nonmember_event(
+ event_dict = {
+ "type": event_type,
+ "content": content,
+ "room_id": room_id,
+ "sender": requester.user.to_string(),
+ }
+
+ if 'ts' in request.args and requester.app_service:
+ event_dict['origin_server_ts'] = parse_integer(request, "ts", 0)
+
+ event = yield self.event_creation_hander.create_and_send_nonmember_event(
requester,
- {
- "type": event_type,
- "content": content,
- "room_id": room_id,
- "sender": requester.user.to_string(),
- },
+ event_dict,
txn_id=txn_id,
)
@@ -222,7 +231,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
class JoinRoomAliasServlet(ClientV1RestServlet):
def __init__(self, hs):
super(JoinRoomAliasServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self.room_member_handler = hs.get_room_member_handler()
def register(self, http_server):
# /join/$room_identifier[/$txn_id]
@@ -238,7 +247,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
try:
content = parse_json_object_from_request(request)
- except:
+ except Exception:
# Turns out we used to ignore the body entirely, and some clients
# cheekily send invalid bodies.
content = {}
@@ -247,10 +256,10 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
room_id = room_identifier
try:
remote_room_hosts = request.args["server_name"]
- except:
+ except Exception:
remote_room_hosts = None
elif RoomAlias.is_valid(room_identifier):
- handler = self.handlers.room_member_handler
+ handler = self.room_member_handler
room_alias = RoomAlias.from_string(room_identifier)
room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias)
room_id = room_id.to_string()
@@ -259,7 +268,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
room_identifier,
))
- yield self.handlers.room_member_handler.update_membership(
+ yield self.room_member_handler.update_membership(
requester=requester,
target=requester.user,
room_id=room_id,
@@ -369,14 +378,13 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomMemberListRestServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self.message_handler = hs.get_message_handler()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens)
requester = yield self.auth.get_user_by_req(request)
- handler = self.handlers.message_handler
- events = yield handler.get_state_events(
+ events = yield self.message_handler.get_state_events(
room_id=room_id,
user_id=requester.user.to_string(),
)
@@ -398,22 +406,18 @@ class JoinedRoomMemberListRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(JoinedRoomMemberListRestServlet, self).__init__(hs)
- self.state = hs.get_state_handler()
+ self.message_handler = hs.get_message_handler()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
- yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request)
- users_with_profile = yield self.state.get_current_user_in_room(room_id)
+ users_with_profile = yield self.message_handler.get_joined_members(
+ requester, room_id,
+ )
defer.returnValue((200, {
- "joined": {
- user_id: {
- "avatar_url": profile.avatar_url,
- "display_name": profile.display_name,
- }
- for user_id, profile in users_with_profile.iteritems()
- }
+ "joined": users_with_profile,
}))
@@ -423,7 +427,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomMessageListRestServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self.pagination_handler = hs.get_pagination_handler()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
@@ -432,14 +436,13 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
request, default_limit=10,
)
as_client_event = "raw" not in request.args
- filter_bytes = request.args.get("filter", None)
+ filter_bytes = parse_string(request, "filter")
if filter_bytes:
- filter_json = urllib.unquote(filter_bytes[-1]).decode("UTF-8")
+ filter_json = urlparse.unquote(filter_bytes).decode("UTF-8")
event_filter = Filter(json.loads(filter_json))
else:
event_filter = None
- handler = self.handlers.message_handler
- msgs = yield handler.get_messages(
+ msgs = yield self.pagination_handler.get_messages(
room_id=room_id,
requester=requester,
pagin_config=pagination_config,
@@ -456,14 +459,13 @@ class RoomStateRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomStateRestServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self.message_handler = hs.get_message_handler()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
- handler = self.handlers.message_handler
# Get all the current state for this room
- events = yield handler.get_state_events(
+ events = yield self.message_handler.get_state_events(
room_id=room_id,
user_id=requester.user.to_string(),
is_guest=requester.is_guest,
@@ -491,23 +493,45 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
defer.returnValue((200, content))
-class RoomEventContext(ClientV1RestServlet):
+class RoomEventServlet(ClientV1RestServlet):
+ PATTERNS = client_path_patterns(
+ "/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$"
+ )
+
+ def __init__(self, hs):
+ super(RoomEventServlet, self).__init__(hs)
+ self.clock = hs.get_clock()
+ self.event_handler = hs.get_event_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, room_id, event_id):
+ requester = yield self.auth.get_user_by_req(request)
+ event = yield self.event_handler.get_event(requester.user, event_id)
+
+ time_now = self.clock.time_msec()
+ if event:
+ defer.returnValue((200, serialize_event(event, time_now)))
+ else:
+ defer.returnValue((404, "Event not found."))
+
+
+class RoomEventContextServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns(
"/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$"
)
def __init__(self, hs):
- super(RoomEventContext, self).__init__(hs)
+ super(RoomEventContextServlet, self).__init__(hs)
self.clock = hs.get_clock()
- self.handlers = hs.get_handlers()
+ self.room_context_handler = hs.get_room_context_handler()
@defer.inlineCallbacks
def on_GET(self, request, room_id, event_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
- limit = int(request.args.get("limit", [10])[0])
+ limit = parse_integer(request, "limit", default=10)
- results = yield self.handlers.room_context_handler.get_event_context(
+ results = yield self.room_context_handler.get_event_context(
requester.user,
room_id,
event_id,
@@ -537,7 +561,7 @@ class RoomEventContext(ClientV1RestServlet):
class RoomForgetRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomForgetRestServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self.room_member_handler = hs.get_room_member_handler()
def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
@@ -550,7 +574,7 @@ class RoomForgetRestServlet(ClientV1RestServlet):
allow_guest=False,
)
- yield self.handlers.room_member_handler.forget(
+ yield self.room_member_handler.forget(
user=requester.user,
room_id=room_id,
)
@@ -568,12 +592,12 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomMembershipRestServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self.room_member_handler = hs.get_room_member_handler()
def register(self, http_server):
# /rooms/$roomid/[invite|join|leave]
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/"
- "(?P<membership_action>join|invite|leave|ban|unban|kick|forget)")
+ "(?P<membership_action>join|invite|leave|ban|unban|kick)")
register_txn_path(self, PATTERNS, http_server)
@defer.inlineCallbacks
@@ -591,13 +615,13 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
try:
content = parse_json_object_from_request(request)
- except:
+ except Exception:
# Turns out we used to ignore the body entirely, and some clients
# cheekily send invalid bodies.
content = {}
if membership_action == "invite" and self._has_3pid_invite_keys(content):
- yield self.handlers.room_member_handler.do_3pid_invite(
+ yield self.room_member_handler.do_3pid_invite(
room_id,
requester.user,
content["medium"],
@@ -611,15 +635,14 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
target = requester.user
if membership_action in ["invite", "ban", "unban", "kick"]:
- if "user_id" not in content:
- raise SynapseError(400, "Missing user_id key.")
+ assert_params_in_dict(content, ["user_id"])
target = UserID.from_string(content["user_id"])
event_content = None
if 'reason' in content and membership_action in ['kick', 'ban']:
event_content = {'reason': content['reason']}
- yield self.handlers.room_member_handler.update_membership(
+ yield self.room_member_handler.update_membership(
requester=requester,
target=target,
room_id=room_id,
@@ -629,7 +652,12 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
content=event_content,
)
- defer.returnValue((200, {}))
+ return_value = {}
+
+ if membership_action == "join":
+ return_value["room_id"] = room_id
+
+ defer.returnValue((200, return_value))
def _has_3pid_invite_keys(self, content):
for key in {"id_server", "medium", "address"}:
@@ -647,6 +675,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomRedactEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
+ self.event_creation_handler = hs.get_event_creation_handler()
def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
@@ -657,8 +686,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
- msg_handler = self.handlers.message_handler
- event = yield msg_handler.create_and_send_nonmember_event(
+ event = yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Redaction,
@@ -692,8 +720,8 @@ class RoomTypingRestServlet(ClientV1RestServlet):
def on_PUT(self, request, room_id, user_id):
requester = yield self.auth.get_user_by_req(request)
- room_id = urllib.unquote(room_id)
- target_user = UserID.from_string(urllib.unquote(user_id))
+ room_id = urlparse.unquote(room_id)
+ target_user = UserID.from_string(urlparse.unquote(user_id))
content = parse_json_object_from_request(request)
@@ -734,7 +762,7 @@ class SearchRestServlet(ClientV1RestServlet):
content = parse_json_object_from_request(request)
- batch = request.args.get("next_batch", [None])[0]
+ batch = parse_string(request, "next_batch")
results = yield self.handlers.search_handler.search(
requester.user,
content,
@@ -802,9 +830,13 @@ def register_servlets(hs, http_server):
RoomSendEventRestServlet(hs).register(http_server)
PublicRoomListRestServlet(hs).register(http_server)
RoomStateRestServlet(hs).register(http_server)
- RoomInitialSyncRestServlet(hs).register(http_server)
RoomRedactEventRestServlet(hs).register(http_server)
RoomTypingRestServlet(hs).register(http_server)
SearchRestServlet(hs).register(http_server)
JoinedRoomsRestServlet(hs).register(http_server)
- RoomEventContext(hs).register(http_server)
+ RoomEventServlet(hs).register(http_server)
+ RoomEventContextServlet(hs).register(http_server)
+
+
+def register_deprecated_servlets(hs, http_server):
+ RoomInitialSyncRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py
index c43b30b73a..62f4c3d93e 100644
--- a/synapse/rest/client/v1/voip.py
+++ b/synapse/rest/client/v1/voip.py
@@ -13,16 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import base64
+import hashlib
+import hmac
+
from twisted.internet import defer
from .base import ClientV1RestServlet, client_path_patterns
-import hmac
-import hashlib
-import base64
-
-
class VoipRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/voip/turnServer$")
diff --git a/synapse/rest/client/v1_only/__init__.py b/synapse/rest/client/v1_only/__init__.py
new file mode 100644
index 0000000000..936f902ace
--- /dev/null
+++ b/synapse/rest/client/v1_only/__init__.py
@@ -0,0 +1,3 @@
+"""
+REST APIs that are only used in v1 (the legacy API).
+"""
diff --git a/synapse/rest/client/v1_only/base.py b/synapse/rest/client/v1_only/base.py
new file mode 100644
index 0000000000..9d4db7437c
--- /dev/null
+++ b/synapse/rest/client/v1_only/base.py
@@ -0,0 +1,39 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""This module contains base REST classes for constructing client v1 servlets.
+"""
+
+import re
+
+from synapse.api.urls import CLIENT_PREFIX
+
+
+def v1_only_client_path_patterns(path_regex, include_in_unstable=True):
+ """Creates a regex compiled client path with the correct client path
+ prefix.
+
+ Args:
+ path_regex (str): The regex string to match. This should NOT have a ^
+ as this will be prefixed.
+ Returns:
+ list of SRE_Pattern
+ """
+ patterns = [re.compile("^" + CLIENT_PREFIX + path_regex)]
+ if include_in_unstable:
+ unstable_prefix = CLIENT_PREFIX.replace("/api/v1", "/unstable")
+ patterns.append(re.compile("^" + unstable_prefix + path_regex))
+ return patterns
diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1_only/register.py
index ecf7e311a9..3439c3c6d4 100644
--- a/synapse/rest/client/v1/register.py
+++ b/synapse/rest/client/v1_only/register.py
@@ -14,21 +14,20 @@
# limitations under the License.
"""This module contains REST servlets to do with registration: /register"""
+import hmac
+import logging
+from hashlib import sha1
+
from twisted.internet import defer
-from synapse.api.errors import SynapseError, Codes
-from synapse.api.constants import LoginType
-from synapse.api.auth import get_access_token_from_request
-from .base import ClientV1RestServlet, client_path_patterns
import synapse.util.stringutils as stringutils
-from synapse.http.servlet import parse_json_object_from_request
+from synapse.api.constants import LoginType
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
+from synapse.rest.client.v1.base import ClientV1RestServlet
from synapse.types import create_requester
-from synapse.util.async import run_on_reactor
-
-from hashlib import sha1
-import hmac
-import logging
+from .base import v1_only_client_path_patterns
logger = logging.getLogger(__name__)
@@ -51,7 +50,7 @@ class RegisterRestServlet(ClientV1RestServlet):
handler doesn't have a concept of multi-stages or sessions.
"""
- PATTERNS = client_path_patterns("/register$", releases=(), include_in_unstable=False)
+ PATTERNS = v1_only_client_path_patterns("/register$", include_in_unstable=False)
def __init__(self, hs):
"""
@@ -66,14 +65,20 @@ class RegisterRestServlet(ClientV1RestServlet):
# TODO: persistent storage
self.sessions = {}
self.enable_registration = hs.config.enable_registration
+ self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self.handlers = hs.get_handlers()
def on_GET(self, request):
+
+ require_email = 'email' in self.hs.config.registrations_require_3pid
+ require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
+
+ flows = []
if self.hs.config.enable_registration_captcha:
- return (
- 200,
- {"flows": [
+ # only support the email-only flow if we don't require MSISDN 3PIDs
+ if not require_msisdn:
+ flows.extend([
{
"type": LoginType.RECAPTCHA,
"stages": [
@@ -82,27 +87,34 @@ class RegisterRestServlet(ClientV1RestServlet):
LoginType.PASSWORD
]
},
+ ])
+ # only support 3PIDless registration if no 3PIDs are required
+ if not require_email and not require_msisdn:
+ flows.extend([
{
"type": LoginType.RECAPTCHA,
"stages": [LoginType.RECAPTCHA, LoginType.PASSWORD]
}
- ]}
- )
+ ])
else:
- return (
- 200,
- {"flows": [
+ # only support the email-only flow if we don't require MSISDN 3PIDs
+ if require_email or not require_msisdn:
+ flows.extend([
{
"type": LoginType.EMAIL_IDENTITY,
"stages": [
LoginType.EMAIL_IDENTITY, LoginType.PASSWORD
]
- },
+ }
+ ])
+ # only support 3PIDless registration if no 3PIDs are required
+ if not require_email and not require_msisdn:
+ flows.extend([
{
"type": LoginType.PASSWORD
}
- ]}
- )
+ ])
+ return (200, {"flows": flows})
@defer.inlineCallbacks
def on_POST(self, request):
@@ -111,8 +123,7 @@ class RegisterRestServlet(ClientV1RestServlet):
session = (register_json["session"]
if "session" in register_json else None)
login_type = None
- if "type" not in register_json:
- raise SynapseError(400, "Missing 'type' key.")
+ assert_params_in_dict(register_json, ["type"])
try:
login_type = register_json["type"]
@@ -258,7 +269,6 @@ class RegisterRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def _do_password(self, request, register_json, session):
- yield run_on_reactor()
if (self.hs.config.enable_registration_captcha and
not session[LoginType.RECAPTCHA]):
# captcha should've been done by this stage!
@@ -298,11 +308,9 @@ class RegisterRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def _do_app_service(self, request, register_json, session):
- as_token = get_access_token_from_request(request)
-
- if "user" not in register_json:
- raise SynapseError(400, "Expected 'user' key.")
+ as_token = self.auth.get_access_token_from_request(request)
+ assert_params_in_dict(register_json, ["user"])
user_localpart = register_json["user"].encode("utf-8")
handler = self.handlers.registration_handler
@@ -319,14 +327,7 @@ class RegisterRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def _do_shared_secret(self, request, register_json, session):
- yield run_on_reactor()
-
- if not isinstance(register_json.get("mac", None), basestring):
- raise SynapseError(400, "Expected mac.")
- if not isinstance(register_json.get("user", None), basestring):
- raise SynapseError(400, "Expected 'user' key.")
- if not isinstance(register_json.get("password", None), basestring):
- raise SynapseError(400, "Expected 'password' key.")
+ assert_params_in_dict(register_json, ["mac", "user", "password"])
if not self.hs.config.registration_shared_secret:
raise SynapseError(400, "Shared secret registration is not enabled")
@@ -336,9 +337,9 @@ class RegisterRestServlet(ClientV1RestServlet):
admin = register_json.get("admin", None)
# Its important to check as we use null bytes as HMAC field separators
- if "\x00" in user:
+ if b"\x00" in user:
raise SynapseError(400, "Invalid user")
- if "\x00" in password:
+ if b"\x00" in password:
raise SynapseError(400, "Invalid password")
# str() because otherwise hmac complains that 'unicode' does not
@@ -346,20 +347,20 @@ class RegisterRestServlet(ClientV1RestServlet):
got_mac = str(register_json["mac"])
want_mac = hmac.new(
- key=self.hs.config.registration_shared_secret,
+ key=self.hs.config.registration_shared_secret.encode(),
digestmod=sha1,
)
want_mac.update(user)
- want_mac.update("\x00")
+ want_mac.update(b"\x00")
want_mac.update(password)
- want_mac.update("\x00")
- want_mac.update("admin" if admin else "notadmin")
+ want_mac.update(b"\x00")
+ want_mac.update(b"admin" if admin else b"notadmin")
want_mac = want_mac.hexdigest()
if compare_digest(want_mac, got_mac):
handler = self.handlers.registration_handler
user_id, token = yield handler.register(
- localpart=user,
+ localpart=user.lower(),
password=password,
admin=bool(admin),
)
@@ -379,7 +380,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
"""Handles user creation via a server-to-server interface
"""
- PATTERNS = client_path_patterns("/createUser$", releases=())
+ PATTERNS = v1_only_client_path_patterns("/createUser$")
def __init__(self, hs):
super(CreateUserRestServlet, self).__init__(hs)
@@ -390,7 +391,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
def on_POST(self, request):
user_json = parse_json_object_from_request(request)
- access_token = get_access_token_from_request(request)
+ access_token = self.auth.get_access_token_from_request(request)
app_service = self.store.get_app_service_by_token(
access_token
)
@@ -409,13 +410,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def _do_create(self, requester, user_json):
- yield run_on_reactor()
-
- if "localpart" not in user_json:
- raise SynapseError(400, "Expected 'localpart' key.")
-
- if "displayname" not in user_json:
- raise SynapseError(400, "Expected 'displayname' key.")
+ assert_params_in_dict(user_json, ["localpart", "displayname"])
localpart = user_json["localpart"].encode("utf-8")
displayname = user_json["displayname"].encode("utf-8")
diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py
index 1f5bc24cc3..77434937ff 100644
--- a/synapse/rest/client/v2_alpha/_base.py
+++ b/synapse/rest/client/v2_alpha/_base.py
@@ -15,12 +15,13 @@
"""This module contains base REST classes for constructing client v1 servlets.
"""
-
-from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
+import logging
import re
-import logging
+from twisted.internet import defer
+from synapse.api.errors import InteractiveAuthIncompleteError
+from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
logger = logging.getLogger(__name__)
@@ -57,3 +58,37 @@ def set_timeline_upper_limit(filter_json, filter_timeline_limit):
filter_json['room']['timeline']["limit"] = min(
filter_json['room']['timeline']['limit'],
filter_timeline_limit)
+
+
+def interactive_auth_handler(orig):
+ """Wraps an on_POST method to handle InteractiveAuthIncompleteErrors
+
+ Takes a on_POST method which returns a deferred (errcode, body) response
+ and adds exception handling to turn a InteractiveAuthIncompleteError into
+ a 401 response.
+
+ Normal usage is:
+
+ @interactive_auth_handler
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ # ...
+ yield self.auth_handler.check_auth
+ """
+ def wrapped(*args, **kwargs):
+ res = defer.maybeDeferred(orig, *args, **kwargs)
+ res.addErrback(_catch_incomplete_interactive_auth)
+ return res
+ return wrapped
+
+
+def _catch_incomplete_interactive_auth(f):
+ """helper for interactive_auth_handler
+
+ Catches InteractiveAuthIncompleteErrors and turns them into 401 responses
+
+ Args:
+ f (failure.Failure):
+ """
+ f.trap(InteractiveAuthIncompleteError)
+ return 401, f.value.result
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 4990b22b9f..eeae466d82 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,21 +14,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
+from six.moves import http_client
from twisted.internet import defer
from synapse.api.constants import LoginType
-from synapse.api.errors import LoginError, SynapseError, Codes
+from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import (
- RestServlet, parse_json_object_from_request, assert_params_in_request
+ RestServlet,
+ assert_params_in_dict,
+ parse_json_object_from_request,
)
-from synapse.util.async import run_on_reactor
from synapse.util.msisdn import phone_number_to_msisdn
+from synapse.util.threepids import check_3pid_allowed
-from ._base import client_v2_patterns
-
-import logging
-
+from ._base import client_v2_patterns, interactive_auth_handler
logger = logging.getLogger(__name__)
@@ -44,10 +47,15 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
def on_POST(self, request):
body = parse_json_object_from_request(request)
- assert_params_in_request(body, [
+ assert_params_in_dict(body, [
'id_server', 'client_secret', 'email', 'send_attempt'
])
+ if not check_3pid_allowed(self.hs, "email", body['email']):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ )
+
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email']
)
@@ -72,13 +80,18 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):
def on_POST(self, request):
body = parse_json_object_from_request(request)
- assert_params_in_request(body, [
+ assert_params_in_dict(body, [
'id_server', 'client_secret',
'country', 'phone_number', 'send_attempt',
])
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
+ if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ )
+
existingUid = yield self.datastore.get_user_id_by_threepid(
'msisdn', msisdn
)
@@ -99,56 +112,60 @@ class PasswordRestServlet(RestServlet):
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self.datastore = self.hs.get_datastore()
+ self._set_password_handler = hs.get_set_password_handler()
+ @interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
- yield run_on_reactor()
-
body = parse_json_object_from_request(request)
- authed, result, params, _ = yield self.auth_handler.check_auth([
- [LoginType.PASSWORD],
- [LoginType.EMAIL_IDENTITY],
- [LoginType.MSISDN],
- ], body, self.hs.get_ip_from_request(request))
-
- if not authed:
- defer.returnValue((401, result))
-
- user_id = None
- requester = None
-
- if LoginType.PASSWORD in result:
- # if using password, they should also be logged in
+ # there are two possibilities here. Either the user does not have an
+ # access token, and needs to do a password reset; or they have one and
+ # need to validate their identity.
+ #
+ # In the first case, we offer a couple of means of identifying
+ # themselves (email and msisdn, though it's unclear if msisdn actually
+ # works).
+ #
+ # In the second case, we require a password to confirm their identity.
+
+ if self.auth.has_access_token(request):
requester = yield self.auth.get_user_by_req(request)
- user_id = requester.user.to_string()
- if user_id != result[LoginType.PASSWORD]:
- raise LoginError(400, "", Codes.UNKNOWN)
- elif LoginType.EMAIL_IDENTITY in result:
- threepid = result[LoginType.EMAIL_IDENTITY]
- if 'medium' not in threepid or 'address' not in threepid:
- raise SynapseError(500, "Malformed threepid")
- if threepid['medium'] == 'email':
- # For emails, transform the address to lowercase.
- # We store all email addreses as lowercase in the DB.
- # (See add_threepid in synapse/handlers/auth.py)
- threepid['address'] = threepid['address'].lower()
- # if using email, we must know about the email they're authing with!
- threepid_user_id = yield self.datastore.get_user_id_by_threepid(
- threepid['medium'], threepid['address']
+ params = yield self.auth_handler.validate_user_via_ui_auth(
+ requester, body, self.hs.get_ip_from_request(request),
)
- if not threepid_user_id:
- raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
- user_id = threepid_user_id
+ user_id = requester.user.to_string()
else:
- logger.error("Auth succeeded but no known type!", result.keys())
- raise SynapseError(500, "", Codes.UNKNOWN)
+ requester = None
+ result, params, _ = yield self.auth_handler.check_auth(
+ [[LoginType.EMAIL_IDENTITY], [LoginType.MSISDN]],
+ body, self.hs.get_ip_from_request(request),
+ )
- if 'new_password' not in params:
- raise SynapseError(400, "", Codes.MISSING_PARAM)
+ if LoginType.EMAIL_IDENTITY in result:
+ threepid = result[LoginType.EMAIL_IDENTITY]
+ if 'medium' not in threepid or 'address' not in threepid:
+ raise SynapseError(500, "Malformed threepid")
+ if threepid['medium'] == 'email':
+ # For emails, transform the address to lowercase.
+ # We store all email addreses as lowercase in the DB.
+ # (See add_threepid in synapse/handlers/auth.py)
+ threepid['address'] = threepid['address'].lower()
+ # if using email, we must know about the email they're authing with!
+ threepid_user_id = yield self.datastore.get_user_id_by_threepid(
+ threepid['medium'], threepid['address']
+ )
+ if not threepid_user_id:
+ raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
+ user_id = threepid_user_id
+ else:
+ logger.error("Auth succeeded but no known type! %r", result.keys())
+ raise SynapseError(500, "", Codes.UNKNOWN)
+
+ assert_params_in_dict(params, ["new_password"])
new_password = params['new_password']
- yield self.auth_handler.set_password(
+ yield self._set_password_handler.set_password(
user_id, new_password, requester
)
@@ -162,42 +179,39 @@ class DeactivateAccountRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/deactivate$")
def __init__(self, hs):
+ super(DeactivateAccountRestServlet, self).__init__()
self.hs = hs
- self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
- super(DeactivateAccountRestServlet, self).__init__()
+ self._deactivate_account_handler = hs.get_deactivate_account_handler()
+ @interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
+ erase = body.get("erase", False)
+ if not isinstance(erase, bool):
+ raise SynapseError(
+ http_client.BAD_REQUEST,
+ "Param 'erase' must be a boolean, if given",
+ Codes.BAD_JSON,
+ )
- authed, result, params, _ = yield self.auth_handler.check_auth([
- [LoginType.PASSWORD],
- ], body, self.hs.get_ip_from_request(request))
-
- if not authed:
- defer.returnValue((401, result))
-
- user_id = None
- requester = None
-
- if LoginType.PASSWORD in result:
- # if using password, they should also be logged in
- requester = yield self.auth.get_user_by_req(request)
- user_id = requester.user.to_string()
- if user_id != result[LoginType.PASSWORD]:
- raise LoginError(400, "", Codes.UNKNOWN)
- else:
- logger.error("Auth succeeded but no known type!", result.keys())
- raise SynapseError(500, "", Codes.UNKNOWN)
+ requester = yield self.auth.get_user_by_req(request)
- # FIXME: Theoretically there is a race here wherein user resets password
- # using threepid.
- yield self.store.user_delete_access_tokens(user_id)
- yield self.store.user_delete_threepids(user_id)
- yield self.store.user_set_password_hash(user_id, None)
+ # allow ASes to dectivate their own users
+ if requester.app_service:
+ yield self._deactivate_account_handler.deactivate_account(
+ requester.user.to_string(), erase,
+ )
+ defer.returnValue((200, {}))
+ yield self.auth_handler.validate_user_via_ui_auth(
+ requester, body, self.hs.get_ip_from_request(request),
+ )
+ yield self._deactivate_account_handler.deactivate_account(
+ requester.user.to_string(), erase,
+ )
defer.returnValue((200, {}))
@@ -213,15 +227,15 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
+ assert_params_in_dict(
+ body,
+ ['id_server', 'client_secret', 'email', 'send_attempt'],
+ )
- required = ['id_server', 'client_secret', 'email', 'send_attempt']
- absent = []
- for k in required:
- if k not in body:
- absent.append(k)
-
- if absent:
- raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+ if not check_3pid_allowed(self.hs, "email", body['email']):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ )
existingUid = yield self.datastore.get_user_id_by_threepid(
'email', body['email']
@@ -246,21 +260,18 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
-
- required = [
+ assert_params_in_dict(body, [
'id_server', 'client_secret',
'country', 'phone_number', 'send_attempt',
- ]
- absent = []
- for k in required:
- if k not in body:
- absent.append(k)
-
- if absent:
- raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+ ])
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
+ if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ )
+
existingUid = yield self.datastore.get_user_id_by_threepid(
'msisdn', msisdn
)
@@ -285,8 +296,6 @@ class ThreepidRestServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
- yield run_on_reactor()
-
requester = yield self.auth.get_user_by_req(request)
threepids = yield self.datastore.user_get_threepids(
@@ -297,8 +306,6 @@ class ThreepidRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
- yield run_on_reactor()
-
body = parse_json_object_from_request(request)
threePidCreds = body.get('threePidCreds')
@@ -350,29 +357,40 @@ class ThreepidDeleteRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
- yield run_on_reactor()
-
body = parse_json_object_from_request(request)
-
- required = ['medium', 'address']
- absent = []
- for k in required:
- if k not in body:
- absent.append(k)
-
- if absent:
- raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+ assert_params_in_dict(body, ['medium', 'address'])
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
- yield self.auth_handler.delete_threepid(
- user_id, body['medium'], body['address']
- )
+ try:
+ yield self.auth_handler.delete_threepid(
+ user_id, body['medium'], body['address']
+ )
+ except Exception:
+ # NB. This endpoint should succeed if there is nothing to
+ # delete, so it should only throw if something is wrong
+ # that we ought to care about.
+ logger.exception("Failed to remove threepid")
+ raise SynapseError(500, "Failed to remove threepid")
defer.returnValue((200, {}))
+class WhoamiRestServlet(RestServlet):
+ PATTERNS = client_v2_patterns("/account/whoami$")
+
+ def __init__(self, hs):
+ super(WhoamiRestServlet, self).__init__()
+ self.auth = hs.get_auth()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request):
+ requester = yield self.auth.get_user_by_req(request)
+
+ defer.returnValue((200, {'user_id': requester.user.to_string()}))
+
+
def register_servlets(hs, http_server):
EmailPasswordRequestTokenRestServlet(hs).register(http_server)
MsisdnPasswordRequestTokenRestServlet(hs).register(http_server)
@@ -382,3 +400,4 @@ def register_servlets(hs, http_server):
MsisdnThreepidRequestTokenRestServlet(hs).register(http_server)
ThreepidRestServlet(hs).register(http_server)
ThreepidDeleteRestServlet(hs).register(http_server)
+ WhoamiRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index 0e0a187efd..371e9aa354 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import client_v2_patterns
-
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from synapse.api.errors import AuthError, SynapseError
+import logging
from twisted.internet import defer
-import logging
+from synapse.api.errors import AuthError, SynapseError
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+
+from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index 8e5577148f..bd8b5f4afa 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
from twisted.internet import defer
from synapse.api.constants import LoginType
@@ -23,9 +25,6 @@ from synapse.http.servlet import RestServlet
from ._base import client_v2_patterns
-import logging
-
-
logger = logging.getLogger(__name__)
RECAPTCHA_TEMPLATE = """
@@ -129,7 +128,6 @@ class AuthRestServlet(RestServlet):
html_bytes = html.encode("utf8")
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(b"Server", self.hs.version_string)
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
request.write(html_bytes)
@@ -175,7 +173,6 @@ class AuthRestServlet(RestServlet):
html_bytes = html.encode("utf8")
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(b"Server", self.hs.version_string)
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
request.write(html_bytes)
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index b57ba95d24..9b75bb1377 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -17,15 +17,20 @@ import logging
from twisted.internet import defer
-from synapse.api import constants, errors
-from synapse.http import servlet
-from ._base import client_v2_patterns
+from synapse.api import errors
+from synapse.http.servlet import (
+ RestServlet,
+ assert_params_in_dict,
+ parse_json_object_from_request,
+)
+
+from ._base import client_v2_patterns, interactive_auth_handler
logger = logging.getLogger(__name__)
-class DevicesRestServlet(servlet.RestServlet):
- PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False)
+class DevicesRestServlet(RestServlet):
+ PATTERNS = client_v2_patterns("/devices$", v2_alpha=False)
def __init__(self, hs):
"""
@@ -46,12 +51,12 @@ class DevicesRestServlet(servlet.RestServlet):
defer.returnValue((200, {"devices": devices}))
-class DeleteDevicesRestServlet(servlet.RestServlet):
+class DeleteDevicesRestServlet(RestServlet):
"""
API for bulk deletion of devices. Accepts a JSON object with a devices
key which lists the device_ids to delete. Requires user interactive auth.
"""
- PATTERNS = client_v2_patterns("/delete_devices", releases=[], v2_alpha=False)
+ PATTERNS = client_v2_patterns("/delete_devices", v2_alpha=False)
def __init__(self, hs):
super(DeleteDevicesRestServlet, self).__init__()
@@ -60,31 +65,28 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
self.device_handler = hs.get_device_handler()
self.auth_handler = hs.get_auth_handler()
+ @interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
+ requester = yield self.auth.get_user_by_req(request)
+
try:
- body = servlet.parse_json_object_from_request(request)
+ body = parse_json_object_from_request(request)
except errors.SynapseError as e:
if e.errcode == errors.Codes.NOT_JSON:
- # deal with older clients which didn't pass a J*DELETESON dict
+ # DELETE
+ # deal with older clients which didn't pass a JSON dict
# the same as those that pass an empty dict
body = {}
else:
raise e
- if 'devices' not in body:
- raise errors.SynapseError(
- 400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM
- )
+ assert_params_in_dict(body, ["devices"])
- authed, result, params, _ = yield self.auth_handler.check_auth([
- [constants.LoginType.PASSWORD],
- ], body, self.hs.get_ip_from_request(request))
-
- if not authed:
- defer.returnValue((401, result))
+ yield self.auth_handler.validate_user_via_ui_auth(
+ requester, body, self.hs.get_ip_from_request(request),
+ )
- requester = yield self.auth.get_user_by_req(request)
yield self.device_handler.delete_devices(
requester.user.to_string(),
body['devices'],
@@ -92,9 +94,8 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
defer.returnValue((200, {}))
-class DeviceRestServlet(servlet.RestServlet):
- PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
- releases=[], v2_alpha=False)
+class DeviceRestServlet(RestServlet):
+ PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", v2_alpha=False)
def __init__(self, hs):
"""
@@ -116,10 +117,13 @@ class DeviceRestServlet(servlet.RestServlet):
)
defer.returnValue((200, device))
+ @interactive_auth_handler
@defer.inlineCallbacks
def on_DELETE(self, request, device_id):
+ requester = yield self.auth.get_user_by_req(request)
+
try:
- body = servlet.parse_json_object_from_request(request)
+ body = parse_json_object_from_request(request)
except errors.SynapseError as e:
if e.errcode == errors.Codes.NOT_JSON:
@@ -129,17 +133,12 @@ class DeviceRestServlet(servlet.RestServlet):
else:
raise
- authed, result, params, _ = yield self.auth_handler.check_auth([
- [constants.LoginType.PASSWORD],
- ], body, self.hs.get_ip_from_request(request))
-
- if not authed:
- defer.returnValue((401, result))
+ yield self.auth_handler.validate_user_via_ui_auth(
+ requester, body, self.hs.get_ip_from_request(request),
+ )
- requester = yield self.auth.get_user_by_req(request)
yield self.device_handler.delete_device(
- requester.user.to_string(),
- device_id,
+ requester.user.to_string(), device_id,
)
defer.returnValue((200, {}))
@@ -147,7 +146,7 @@ class DeviceRestServlet(servlet.RestServlet):
def on_PUT(self, request, device_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
- body = servlet.parse_json_object_from_request(request)
+ body = parse_json_object_from_request(request)
yield self.device_handler.update_device(
requester.user.to_string(),
device_id,
diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py
index d2b2fd66e6..ae86728879 100644
--- a/synapse/rest/client/v2_alpha/filter.py
+++ b/synapse/rest/client/v2_alpha/filter.py
@@ -13,17 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
from twisted.internet import defer
-from synapse.api.errors import AuthError, SynapseError, StoreError, Codes
+from synapse.api.errors import AuthError, Codes, StoreError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import UserID
-from ._base import client_v2_patterns
-from ._base import set_timeline_upper_limit
-
-import logging
-
+from ._base import client_v2_patterns, set_timeline_upper_limit
logger = logging.getLogger(__name__)
@@ -50,7 +48,7 @@ class GetFilterRestServlet(RestServlet):
try:
filter_id = int(filter_id)
- except:
+ except Exception:
raise SynapseError(400, "Invalid filter_id")
try:
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
new file mode 100644
index 0000000000..21e02c07c0
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -0,0 +1,786 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from twisted.internet import defer
+
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.types import GroupID
+
+from ._base import client_v2_patterns
+
+logger = logging.getLogger(__name__)
+
+
+class GroupServlet(RestServlet):
+ """Get the group profile
+ """
+ PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/profile$")
+
+ def __init__(self, hs):
+ super(GroupServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester_user_id = requester.user.to_string()
+
+ group_description = yield self.groups_handler.get_group_profile(
+ group_id,
+ requester_user_id,
+ )
+
+ defer.returnValue((200, group_description))
+
+ @defer.inlineCallbacks
+ def on_POST(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ yield self.groups_handler.update_group_profile(
+ group_id, requester_user_id, content,
+ )
+
+ defer.returnValue((200, {}))
+
+
+class GroupSummaryServlet(RestServlet):
+ """Get the full group summary
+ """
+ PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/summary$")
+
+ def __init__(self, hs):
+ super(GroupSummaryServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester_user_id = requester.user.to_string()
+
+ get_group_summary = yield self.groups_handler.get_group_summary(
+ group_id,
+ requester_user_id,
+ )
+
+ defer.returnValue((200, get_group_summary))
+
+
+class GroupSummaryRoomsCatServlet(RestServlet):
+ """Update/delete a rooms entry in the summary.
+
+ Matches both:
+ - /groups/:group/summary/rooms/:room_id
+ - /groups/:group/summary/categories/:category/rooms/:room_id
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/summary"
+ "(/categories/(?P<category_id>[^/]+))?"
+ "/rooms/(?P<room_id>[^/]*)$"
+ )
+
+ def __init__(self, hs):
+ super(GroupSummaryRoomsCatServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id, category_id, room_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ resp = yield self.groups_handler.update_group_summary_room(
+ group_id, requester_user_id,
+ room_id=room_id,
+ category_id=category_id,
+ content=content,
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, request, group_id, category_id, room_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ resp = yield self.groups_handler.delete_group_summary_room(
+ group_id, requester_user_id,
+ room_id=room_id,
+ category_id=category_id,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class GroupCategoryServlet(RestServlet):
+ """Get/add/update/delete a group category
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
+ )
+
+ def __init__(self, hs):
+ super(GroupCategoryServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, group_id, category_id):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester_user_id = requester.user.to_string()
+
+ category = yield self.groups_handler.get_group_category(
+ group_id, requester_user_id,
+ category_id=category_id,
+ )
+
+ defer.returnValue((200, category))
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id, category_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ resp = yield self.groups_handler.update_group_category(
+ group_id, requester_user_id,
+ category_id=category_id,
+ content=content,
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, request, group_id, category_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ resp = yield self.groups_handler.delete_group_category(
+ group_id, requester_user_id,
+ category_id=category_id,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class GroupCategoriesServlet(RestServlet):
+ """Get all group categories
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/categories/$"
+ )
+
+ def __init__(self, hs):
+ super(GroupCategoriesServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester_user_id = requester.user.to_string()
+
+ category = yield self.groups_handler.get_group_categories(
+ group_id, requester_user_id,
+ )
+
+ defer.returnValue((200, category))
+
+
+class GroupRoleServlet(RestServlet):
+ """Get/add/update/delete a group role
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$"
+ )
+
+ def __init__(self, hs):
+ super(GroupRoleServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, group_id, role_id):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester_user_id = requester.user.to_string()
+
+ category = yield self.groups_handler.get_group_role(
+ group_id, requester_user_id,
+ role_id=role_id,
+ )
+
+ defer.returnValue((200, category))
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id, role_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ resp = yield self.groups_handler.update_group_role(
+ group_id, requester_user_id,
+ role_id=role_id,
+ content=content,
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, request, group_id, role_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ resp = yield self.groups_handler.delete_group_role(
+ group_id, requester_user_id,
+ role_id=role_id,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class GroupRolesServlet(RestServlet):
+ """Get all group roles
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/roles/$"
+ )
+
+ def __init__(self, hs):
+ super(GroupRolesServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester_user_id = requester.user.to_string()
+
+ category = yield self.groups_handler.get_group_roles(
+ group_id, requester_user_id,
+ )
+
+ defer.returnValue((200, category))
+
+
+class GroupSummaryUsersRoleServlet(RestServlet):
+ """Update/delete a user's entry in the summary.
+
+ Matches both:
+ - /groups/:group/summary/users/:room_id
+ - /groups/:group/summary/roles/:role/users/:user_id
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/summary"
+ "(/roles/(?P<role_id>[^/]+))?"
+ "/users/(?P<user_id>[^/]*)$"
+ )
+
+ def __init__(self, hs):
+ super(GroupSummaryUsersRoleServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id, role_id, user_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ resp = yield self.groups_handler.update_group_summary_user(
+ group_id, requester_user_id,
+ user_id=user_id,
+ role_id=role_id,
+ content=content,
+ )
+
+ defer.returnValue((200, resp))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, request, group_id, role_id, user_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ resp = yield self.groups_handler.delete_group_summary_user(
+ group_id, requester_user_id,
+ user_id=user_id,
+ role_id=role_id,
+ )
+
+ defer.returnValue((200, resp))
+
+
+class GroupRoomServlet(RestServlet):
+ """Get all rooms in a group
+ """
+ PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/rooms$")
+
+ def __init__(self, hs):
+ super(GroupRoomServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester_user_id = requester.user.to_string()
+
+ result = yield self.groups_handler.get_rooms_in_group(group_id, requester_user_id)
+
+ defer.returnValue((200, result))
+
+
+class GroupUsersServlet(RestServlet):
+ """Get all users in a group
+ """
+ PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/users$")
+
+ def __init__(self, hs):
+ super(GroupUsersServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester_user_id = requester.user.to_string()
+
+ result = yield self.groups_handler.get_users_in_group(group_id, requester_user_id)
+
+ defer.returnValue((200, result))
+
+
+class GroupInvitedUsersServlet(RestServlet):
+ """Get users invited to a group
+ """
+ PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/invited_users$")
+
+ def __init__(self, hs):
+ super(GroupInvitedUsersServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ result = yield self.groups_handler.get_invited_users_in_group(
+ group_id,
+ requester_user_id,
+ )
+
+ defer.returnValue((200, result))
+
+
+class GroupSettingJoinPolicyServlet(RestServlet):
+ """Set group join policy
+ """
+ PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$")
+
+ def __init__(self, hs):
+ super(GroupSettingJoinPolicyServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+
+ result = yield self.groups_handler.set_group_join_policy(
+ group_id,
+ requester_user_id,
+ content,
+ )
+
+ defer.returnValue((200, result))
+
+
+class GroupCreateServlet(RestServlet):
+ """Create a group
+ """
+ PATTERNS = client_v2_patterns("/create_group$")
+
+ def __init__(self, hs):
+ super(GroupCreateServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+ self.server_name = hs.hostname
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ # TODO: Create group on remote server
+ content = parse_json_object_from_request(request)
+ localpart = content.pop("localpart")
+ group_id = GroupID(localpart, self.server_name).to_string()
+
+ result = yield self.groups_handler.create_group(
+ group_id,
+ requester_user_id,
+ content,
+ )
+
+ defer.returnValue((200, result))
+
+
+class GroupAdminRoomsServlet(RestServlet):
+ """Add a room to the group
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$"
+ )
+
+ def __init__(self, hs):
+ super(GroupAdminRoomsServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id, room_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ result = yield self.groups_handler.add_room_to_group(
+ group_id, requester_user_id, room_id, content,
+ )
+
+ defer.returnValue((200, result))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, request, group_id, room_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ result = yield self.groups_handler.remove_room_from_group(
+ group_id, requester_user_id, room_id,
+ )
+
+ defer.returnValue((200, result))
+
+
+class GroupAdminRoomsConfigServlet(RestServlet):
+ """Update the config of a room in a group
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)"
+ "/config/(?P<config_key>[^/]*)$"
+ )
+
+ def __init__(self, hs):
+ super(GroupAdminRoomsConfigServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id, room_id, config_key):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ result = yield self.groups_handler.update_room_in_group(
+ group_id, requester_user_id, room_id, config_key, content,
+ )
+
+ defer.returnValue((200, result))
+
+
+class GroupAdminUsersInviteServlet(RestServlet):
+ """Invite a user to the group
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$"
+ )
+
+ def __init__(self, hs):
+ super(GroupAdminUsersInviteServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+ self.store = hs.get_datastore()
+ self.is_mine_id = hs.is_mine_id
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id, user_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ config = content.get("config", {})
+ result = yield self.groups_handler.invite(
+ group_id, user_id, requester_user_id, config,
+ )
+
+ defer.returnValue((200, result))
+
+
+class GroupAdminUsersKickServlet(RestServlet):
+ """Kick a user from the group
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$"
+ )
+
+ def __init__(self, hs):
+ super(GroupAdminUsersKickServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id, user_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ result = yield self.groups_handler.remove_user_from_group(
+ group_id, user_id, requester_user_id, content,
+ )
+
+ defer.returnValue((200, result))
+
+
+class GroupSelfLeaveServlet(RestServlet):
+ """Leave a joined group
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/self/leave$"
+ )
+
+ def __init__(self, hs):
+ super(GroupSelfLeaveServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ result = yield self.groups_handler.remove_user_from_group(
+ group_id, requester_user_id, requester_user_id, content,
+ )
+
+ defer.returnValue((200, result))
+
+
+class GroupSelfJoinServlet(RestServlet):
+ """Attempt to join a group, or knock
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/self/join$"
+ )
+
+ def __init__(self, hs):
+ super(GroupSelfJoinServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ result = yield self.groups_handler.join_group(
+ group_id, requester_user_id, content,
+ )
+
+ defer.returnValue((200, result))
+
+
+class GroupSelfAcceptInviteServlet(RestServlet):
+ """Accept a group invite
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/self/accept_invite$"
+ )
+
+ def __init__(self, hs):
+ super(GroupSelfAcceptInviteServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ result = yield self.groups_handler.accept_invite(
+ group_id, requester_user_id, content,
+ )
+
+ defer.returnValue((200, result))
+
+
+class GroupSelfUpdatePublicityServlet(RestServlet):
+ """Update whether we publicise a users membership of a group
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/self/update_publicity$"
+ )
+
+ def __init__(self, hs):
+ super(GroupSelfUpdatePublicityServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ publicise = content["publicise"]
+ yield self.store.update_group_publicity(
+ group_id, requester_user_id, publicise,
+ )
+
+ defer.returnValue((200, {}))
+
+
+class PublicisedGroupsForUserServlet(RestServlet):
+ """Get the list of groups a user is advertising
+ """
+ PATTERNS = client_v2_patterns(
+ "/publicised_groups/(?P<user_id>[^/]*)$"
+ )
+
+ def __init__(self, hs):
+ super(PublicisedGroupsForUserServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.store = hs.get_datastore()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, user_id):
+ yield self.auth.get_user_by_req(request, allow_guest=True)
+
+ result = yield self.groups_handler.get_publicised_groups_for_user(
+ user_id
+ )
+
+ defer.returnValue((200, result))
+
+
+class PublicisedGroupsForUsersServlet(RestServlet):
+ """Get the list of groups a user is advertising
+ """
+ PATTERNS = client_v2_patterns(
+ "/publicised_groups$"
+ )
+
+ def __init__(self, hs):
+ super(PublicisedGroupsForUsersServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.store = hs.get_datastore()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ yield self.auth.get_user_by_req(request, allow_guest=True)
+
+ content = parse_json_object_from_request(request)
+ user_ids = content["user_ids"]
+
+ result = yield self.groups_handler.bulk_get_publicised_groups(
+ user_ids
+ )
+
+ defer.returnValue((200, result))
+
+
+class GroupsForUserServlet(RestServlet):
+ """Get all groups the logged in user is joined to
+ """
+ PATTERNS = client_v2_patterns(
+ "/joined_groups$"
+ )
+
+ def __init__(self, hs):
+ super(GroupsForUserServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester_user_id = requester.user.to_string()
+
+ result = yield self.groups_handler.get_joined_groups(requester_user_id)
+
+ defer.returnValue((200, result))
+
+
+def register_servlets(hs, http_server):
+ GroupServlet(hs).register(http_server)
+ GroupSummaryServlet(hs).register(http_server)
+ GroupInvitedUsersServlet(hs).register(http_server)
+ GroupUsersServlet(hs).register(http_server)
+ GroupRoomServlet(hs).register(http_server)
+ GroupSettingJoinPolicyServlet(hs).register(http_server)
+ GroupCreateServlet(hs).register(http_server)
+ GroupAdminRoomsServlet(hs).register(http_server)
+ GroupAdminRoomsConfigServlet(hs).register(http_server)
+ GroupAdminUsersInviteServlet(hs).register(http_server)
+ GroupAdminUsersKickServlet(hs).register(http_server)
+ GroupSelfLeaveServlet(hs).register(http_server)
+ GroupSelfJoinServlet(hs).register(http_server)
+ GroupSelfAcceptInviteServlet(hs).register(http_server)
+ GroupsForUserServlet(hs).register(http_server)
+ GroupCategoryServlet(hs).register(http_server)
+ GroupCategoriesServlet(hs).register(http_server)
+ GroupSummaryRoomsCatServlet(hs).register(http_server)
+ GroupRoleServlet(hs).register(http_server)
+ GroupRolesServlet(hs).register(http_server)
+ GroupSelfUpdatePublicityServlet(hs).register(http_server)
+ GroupSummaryUsersRoleServlet(hs).register(http_server)
+ PublicisedGroupsForUserServlet(hs).register(http_server)
+ PublicisedGroupsForUsersServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 6a3cfe84f8..8486086b51 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -19,10 +19,13 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.http.servlet import (
- RestServlet, parse_json_object_from_request, parse_integer
+ RestServlet,
+ parse_integer,
+ parse_json_object_from_request,
+ parse_string,
)
-from synapse.http.servlet import parse_string
from synapse.types import StreamToken
+
from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
@@ -53,8 +56,7 @@ class KeyUploadServlet(RestServlet):
},
}
"""
- PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$",
- releases=())
+ PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
def __init__(self, hs):
"""
@@ -128,10 +130,7 @@ class KeyQueryServlet(RestServlet):
} } } } } }
"""
- PATTERNS = client_v2_patterns(
- "/keys/query$",
- releases=()
- )
+ PATTERNS = client_v2_patterns("/keys/query$")
def __init__(self, hs):
"""
@@ -160,10 +159,7 @@ class KeyChangesServlet(RestServlet):
200 OK
{ "changed": ["@foo:example.com"] }
"""
- PATTERNS = client_v2_patterns(
- "/keys/changes$",
- releases=()
- )
+ PATTERNS = client_v2_patterns("/keys/changes$")
def __init__(self, hs):
"""
@@ -188,13 +184,11 @@ class KeyChangesServlet(RestServlet):
user_id = requester.user.to_string()
- changed = yield self.device_handler.get_user_ids_changed(
+ results = yield self.device_handler.get_user_ids_changed(
user_id, from_token,
)
- defer.returnValue((200, {
- "changed": list(changed),
- }))
+ defer.returnValue((200, results))
class OneTimeKeyServlet(RestServlet):
@@ -215,10 +209,7 @@ class OneTimeKeyServlet(RestServlet):
} } } }
"""
- PATTERNS = client_v2_patterns(
- "/keys/claim$",
- releases=()
- )
+ PATTERNS = client_v2_patterns("/keys/claim$")
def __init__(self, hs):
super(OneTimeKeyServlet, self).__init__()
diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py
index fd2a3d69d4..2a6ea3df5f 100644
--- a/synapse/rest/client/v2_alpha/notifications.py
+++ b/synapse/rest/client/v2_alpha/notifications.py
@@ -13,24 +13,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
from twisted.internet import defer
-from synapse.http.servlet import (
- RestServlet, parse_string, parse_integer
-)
from synapse.events.utils import (
- serialize_event, format_event_for_client_v2_without_room_id,
+ format_event_for_client_v2_without_room_id,
+ serialize_event,
)
+from synapse.http.servlet import RestServlet, parse_integer, parse_string
from ._base import client_v2_patterns
-import logging
-
logger = logging.getLogger(__name__)
class NotificationsServlet(RestServlet):
- PATTERNS = client_v2_patterns("/notifications$", releases=())
+ PATTERNS = client_v2_patterns("/notifications$")
def __init__(self, hs):
super(NotificationsServlet, self).__init__()
@@ -88,7 +87,7 @@ class NotificationsServlet(RestServlet):
pa["topological_ordering"], pa["stream_ordering"]
)
returned_push_actions.append(returned_pa)
- next_token = pa["stream_ordering"]
+ next_token = str(pa["stream_ordering"])
defer.returnValue((200, {
"notifications": returned_push_actions,
diff --git a/synapse/rest/client/v2_alpha/openid.py b/synapse/rest/client/v2_alpha/openid.py
index aa1cae8e1e..01c90aa2a3 100644
--- a/synapse/rest/client/v2_alpha/openid.py
+++ b/synapse/rest/client/v2_alpha/openid.py
@@ -14,15 +14,15 @@
# limitations under the License.
-from ._base import client_v2_patterns
+import logging
+
+from twisted.internet import defer
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.api.errors import AuthError
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.util.stringutils import random_string
-from twisted.internet import defer
-
-import logging
+from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py
index 2f8784fe06..a6e582a5ae 100644
--- a/synapse/rest/client/v2_alpha/read_marker.py
+++ b/synapse/rest/client/v2_alpha/read_marker.py
@@ -13,13 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
from twisted.internet import defer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from ._base import client_v2_patterns
-
-import logging
+from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py
index 1fbff2edd8..de370cac45 100644
--- a/synapse/rest/client/v2_alpha/receipts.py
+++ b/synapse/rest/client/v2_alpha/receipts.py
@@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet
-from ._base import client_v2_patterns
-
-import logging
+from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 1421c18152..d6cf915d86 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -14,25 +14,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import hmac
+import logging
+from hashlib import sha1
+
+from six import string_types
+
from twisted.internet import defer
import synapse
-from synapse.api.auth import get_access_token_from_request, has_access_token
+import synapse.types
from synapse.api.constants import LoginType
-from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
+from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError
from synapse.http.servlet import (
- RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string
+ RestServlet,
+ assert_params_in_dict,
+ parse_json_object_from_request,
+ parse_string,
)
from synapse.util.msisdn import phone_number_to_msisdn
-
-from ._base import client_v2_patterns
-
-import logging
-import hmac
-from hashlib import sha1
-from synapse.util.async import run_on_reactor
from synapse.util.ratelimitutils import FederationRateLimiter
+from synapse.util.threepids import check_3pid_allowed
+from ._base import client_v2_patterns, interactive_auth_handler
# We ought to be using hmac.compare_digest() but on older pythons it doesn't
# exist. It's a _really minor_ security flaw to use plain string comparison
@@ -64,10 +68,15 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
def on_POST(self, request):
body = parse_json_object_from_request(request)
- assert_params_in_request(body, [
+ assert_params_in_dict(body, [
'id_server', 'client_secret', 'email', 'send_attempt'
])
+ if not check_3pid_allowed(self.hs, "email", body['email']):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ )
+
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email']
)
@@ -95,7 +104,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
def on_POST(self, request):
body = parse_json_object_from_request(request)
- assert_params_in_request(body, [
+ assert_params_in_dict(body, [
'id_server', 'client_secret',
'country', 'phone_number',
'send_attempt',
@@ -103,6 +112,11 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
+ if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ )
+
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'msisdn', msisdn
)
@@ -170,13 +184,13 @@ class RegisterRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler()
self.registration_handler = hs.get_handlers().registration_handler
self.identity_handler = hs.get_handlers().identity_handler
+ self.room_member_handler = hs.get_room_member_handler()
self.device_handler = hs.get_device_handler()
self.macaroon_gen = hs.get_macaroon_generator()
+ @interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
- yield run_on_reactor()
-
body = parse_json_object_from_request(request)
kind = "user"
@@ -196,20 +210,20 @@ class RegisterRestServlet(RestServlet):
# in sessions. Pull out the username/password provided to us.
desired_password = None
if 'password' in body:
- if (not isinstance(body['password'], basestring) or
+ if (not isinstance(body['password'], string_types) or
len(body['password']) > 512):
raise SynapseError(400, "Invalid password")
desired_password = body["password"]
desired_username = None
if 'username' in body:
- if (not isinstance(body['username'], basestring) or
+ if (not isinstance(body['username'], string_types) or
len(body['username']) > 512):
raise SynapseError(400, "Invalid username")
desired_username = body['username']
appservice = None
- if has_access_token(request):
+ if self.auth.has_access_token(request):
appservice = yield self.auth.get_appservice_by_req(request)
# fork off as soon as possible for ASes and shared secret auth which
@@ -221,15 +235,30 @@ class RegisterRestServlet(RestServlet):
# 'user' key not 'username'). Since this is a new addition, we'll
# fallback to 'username' if they gave one.
desired_username = body.get("user", desired_username)
- access_token = get_access_token_from_request(request)
- if isinstance(desired_username, basestring):
+ # XXX we should check that desired_username is valid. Currently
+ # we give appservices carte blanche for any insanity in mxids,
+ # because the IRC bridges rely on being able to register stupid
+ # IDs.
+
+ access_token = self.auth.get_access_token_from_request(request)
+
+ if isinstance(desired_username, string_types):
result = yield self._do_appservice_registration(
desired_username, access_token, body
)
defer.returnValue((200, result)) # we throw for non 200 responses
return
+ # for either shared secret or regular registration, downcase the
+ # provided username before attempting to register it. This should mean
+ # that people who try to register with upper-case in their usernames
+ # don't get a nasty surprise. (Note that we treat username
+ # case-insenstively in login, so they are free to carry on imagining
+ # that their username is CrAzYh4cKeR if that keeps them happy)
+ if desired_username is not None:
+ desired_username = desired_username.lower()
+
# == Shared Secret Registration == (e.g. create new user scripts)
if 'mac' in body:
# FIXME: Should we really be determining if this is shared secret
@@ -286,34 +315,66 @@ class RegisterRestServlet(RestServlet):
if 'x_show_msisdn' in body and body['x_show_msisdn']:
show_msisdn = True
+ # FIXME: need a better error than "no auth flow found" for scenarios
+ # where we required 3PID for registration but the user didn't give one
+ require_email = 'email' in self.hs.config.registrations_require_3pid
+ require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
+
+ flows = []
if self.hs.config.enable_registration_captcha:
- flows = [
- [LoginType.RECAPTCHA],
- [LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
- ]
+ # only support 3PIDless registration if no 3PIDs are required
+ if not require_email and not require_msisdn:
+ flows.extend([[LoginType.RECAPTCHA]])
+ # only support the email-only flow if we don't require MSISDN 3PIDs
+ if not require_msisdn:
+ flows.extend([[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]])
+
if show_msisdn:
+ # only support the MSISDN-only flow if we don't require email 3PIDs
+ if not require_email:
+ flows.extend([[LoginType.MSISDN, LoginType.RECAPTCHA]])
+ # always let users provide both MSISDN & email
flows.extend([
- [LoginType.MSISDN, LoginType.RECAPTCHA],
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
])
else:
- flows = [
- [LoginType.DUMMY],
- [LoginType.EMAIL_IDENTITY],
- ]
+ # only support 3PIDless registration if no 3PIDs are required
+ if not require_email and not require_msisdn:
+ flows.extend([[LoginType.DUMMY]])
+ # only support the email-only flow if we don't require MSISDN 3PIDs
+ if not require_msisdn:
+ flows.extend([[LoginType.EMAIL_IDENTITY]])
+
if show_msisdn:
+ # only support the MSISDN-only flow if we don't require email 3PIDs
+ if not require_email or require_msisdn:
+ flows.extend([[LoginType.MSISDN]])
+ # always let users provide both MSISDN & email
flows.extend([
- [LoginType.MSISDN],
- [LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
+ [LoginType.MSISDN, LoginType.EMAIL_IDENTITY]
])
- authed, auth_result, params, session_id = yield self.auth_handler.check_auth(
+ auth_result, params, session_id = yield self.auth_handler.check_auth(
flows, body, self.hs.get_ip_from_request(request)
)
- if not authed:
- defer.returnValue((401, auth_result))
- return
+ # Check that we're not trying to register a denied 3pid.
+ #
+ # the user-facing checks will probably already have happened in
+ # /register/email/requestToken when we requested a 3pid, but that's not
+ # guaranteed.
+
+ if auth_result:
+ for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]:
+ if login_type in auth_result:
+ medium = auth_result[login_type]['medium']
+ address = auth_result[login_type]['address']
+
+ if not check_3pid_allowed(self.hs, medium, address):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed",
+ Codes.THREEPID_DENIED,
+ )
if registered_user_id is not None:
logger.info(
@@ -325,14 +386,15 @@ class RegisterRestServlet(RestServlet):
add_msisdn = False
else:
# NB: This may be from the auth handler and NOT from the POST
- if 'password' not in params:
- raise SynapseError(400, "Missing password.",
- Codes.MISSING_PARAM)
+ assert_params_in_dict(params, ["password"])
desired_username = params.get("username", None)
new_password = params.get("password", None)
guest_access_token = params.get("guest_access_token", None)
+ if desired_username is not None:
+ desired_username = desired_username.lower()
+
(registered_user_id, _) = yield self.registration_handler.register(
localpart=desired_username,
password=new_password,
@@ -383,15 +445,24 @@ class RegisterRestServlet(RestServlet):
def _do_shared_secret_registration(self, username, password, body):
if not self.hs.config.registration_shared_secret:
raise SynapseError(400, "Shared secret registration is not enabled")
+ if not username:
+ raise SynapseError(
+ 400, "username must be specified", errcode=Codes.BAD_JSON,
+ )
- user = username.encode("utf-8")
+ # use the username from the original request rather than the
+ # downcased one in `username` for the mac calculation
+ user = body["username"].encode("utf-8")
# str() because otherwise hmac complains that 'unicode' does not
# have the buffer interface
got_mac = str(body["mac"])
+ # FIXME this is different to the /v1/register endpoint, which
+ # includes the password and admin flag in the hashed text. Why are
+ # these different?
want_mac = hmac.new(
- key=self.hs.config.registration_shared_secret,
+ key=self.hs.config.registration_shared_secret.encode(),
msg=user,
digestmod=sha1,
).hexdigest()
@@ -492,11 +563,14 @@ class RegisterRestServlet(RestServlet):
Returns:
defer.Deferred:
"""
- reqd = ('medium', 'address', 'validated_at')
- if any(x not in threepid for x in reqd):
- # This will only happen if the ID server returns a malformed response
- logger.info("Can't add incomplete 3pid")
- defer.returnValue()
+ try:
+ assert_params_in_dict(threepid, ['medium', 'address', 'validated_at'])
+ except SynapseError as ex:
+ if ex.errcode == Codes.MISSING_PARAM:
+ # This will only happen if the ID server returns a malformed response
+ logger.info("Can't add incomplete 3pid")
+ defer.returnValue(None)
+ raise
yield self.auth_handler.add_threepid(
user_id,
@@ -523,25 +597,28 @@ class RegisterRestServlet(RestServlet):
Args:
(str) user_id: full canonical @user:id
(object) params: registration parameters, from which we pull
- device_id and initial_device_name
+ device_id, initial_device_name and inhibit_login
Returns:
defer.Deferred: (object) dictionary for response from /register
"""
- device_id = yield self._register_device(user_id, params)
+ result = {
+ "user_id": user_id,
+ "home_server": self.hs.hostname,
+ }
+ if not params.get("inhibit_login", False):
+ device_id = yield self._register_device(user_id, params)
- access_token = (
- yield self.auth_handler.get_access_token_for_user_id(
- user_id, device_id=device_id,
- initial_display_name=params.get("initial_device_display_name")
+ access_token = (
+ yield self.auth_handler.get_access_token_for_user_id(
+ user_id, device_id=device_id,
+ )
)
- )
- defer.returnValue({
- "user_id": user_id,
- "access_token": access_token,
- "home_server": self.hs.hostname,
- "device_id": device_id,
- })
+ result.update({
+ "access_token": access_token,
+ "device_id": device_id,
+ })
+ defer.returnValue(result)
def _register_device(self, user_id, params):
"""Register a device for a user.
@@ -566,7 +643,7 @@ class RegisterRestServlet(RestServlet):
@defer.inlineCallbacks
def _do_guest_registration(self, params):
if not self.hs.config.allow_guest_access:
- defer.returnValue((403, "Guest access is disabled"))
+ raise SynapseError(403, "Guest access is disabled")
user_id, _ = yield self.registration_handler.register(
generate_token=False,
make_guest=True
diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py
index 8903e12405..95d2a71ec2 100644
--- a/synapse/rest/client/v2_alpha/report_event.py
+++ b/synapse/rest/client/v2_alpha/report_event.py
@@ -13,13 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
+import logging
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from ._base import client_v2_patterns
+from six import string_types
+from six.moves import http_client
-import logging
+from twisted.internet import defer
+
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.servlet import (
+ RestServlet,
+ assert_params_in_dict,
+ parse_json_object_from_request,
+)
+from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
@@ -42,12 +50,26 @@ class ReportEventRestServlet(RestServlet):
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
+ assert_params_in_dict(body, ("reason", "score"))
+
+ if not isinstance(body["reason"], string_types):
+ raise SynapseError(
+ http_client.BAD_REQUEST,
+ "Param 'reason' must be a string",
+ Codes.BAD_JSON,
+ )
+ if not isinstance(body["score"], int):
+ raise SynapseError(
+ http_client.BAD_REQUEST,
+ "Param 'score' must be an integer",
+ Codes.BAD_JSON,
+ )
yield self.store.add_event_report(
room_id=room_id,
event_id=event_id,
user_id=user_id,
- reason=body.get("reason"),
+ reason=body["reason"],
content=body,
received_ts=self.clock.time_msec(),
)
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index d607bd2970..a9e9a47a0b 100644
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
class SendToDeviceRestServlet(servlet.RestServlet):
PATTERNS = client_v2_patterns(
"/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$",
- releases=[], v2_alpha=False
+ v2_alpha=False
)
def __init__(self, hs):
@@ -40,7 +40,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
super(SendToDeviceRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
- self.txns = HttpTransactionCache(hs.get_clock())
+ self.txns = HttpTransactionCache(hs)
self.device_message_handler = hs.get_device_message_handler()
def on_PUT(self, request, message_type, txn_id):
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 83e209d18f..8aa06faf23 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -13,27 +13,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import itertools
+import logging
+
+from canonicaljson import json
+
from twisted.internet import defer
-from synapse.http.servlet import (
- RestServlet, parse_string, parse_integer, parse_boolean
+from synapse.api.constants import PresenceState
+from synapse.api.errors import SynapseError
+from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
+from synapse.events.utils import (
+ format_event_for_client_v2_without_room_id,
+ serialize_event,
)
from synapse.handlers.presence import format_user_presence_state
from synapse.handlers.sync import SyncConfig
+from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.types import StreamToken
-from synapse.events.utils import (
- serialize_event, format_event_for_client_v2_without_room_id,
-)
-from synapse.api.filtering import FilterCollection, DEFAULT_FILTER_COLLECTION
-from synapse.api.errors import SynapseError
-from synapse.api.constants import PresenceState
-from ._base import client_v2_patterns
-from ._base import set_timeline_upper_limit
-
-import itertools
-import logging
-import ujson as json
+from ._base import client_v2_patterns, set_timeline_upper_limit
logger = logging.getLogger(__name__)
@@ -85,6 +84,7 @@ class SyncRestServlet(RestServlet):
self.clock = hs.get_clock()
self.filtering = hs.get_filtering()
self.presence_handler = hs.get_presence_handler()
+ self._server_notices_sender = hs.get_server_notices_sender()
@defer.inlineCallbacks
def on_GET(self, request):
@@ -110,7 +110,7 @@ class SyncRestServlet(RestServlet):
filter_id = parse_string(request, "filter", default=None)
full_state = parse_boolean(request, "full_state", default=False)
- logger.info(
+ logger.debug(
"/sync: user=%r, timeout=%r, since=%r,"
" set_presence=%r, filter_id=%r, device_id=%r" % (
user, timeout, since, set_presence, filter_id, device_id
@@ -125,7 +125,7 @@ class SyncRestServlet(RestServlet):
filter_object = json.loads(filter_id)
set_timeline_upper_limit(filter_object,
self.hs.config.filter_timeline_limit)
- except:
+ except Exception:
raise SynapseError(400, "Invalid filter JSON")
self.filtering.check_valid_filter(filter_object)
filter = FilterCollection(filter_object)
@@ -149,6 +149,9 @@ class SyncRestServlet(RestServlet):
else:
since_token = None
+ # send any outstanding server notices to the user.
+ yield self._server_notices_sender.on_user_syncing(user.to_string())
+
affect_presence = set_presence != PresenceState.OFFLINE
if affect_presence:
@@ -164,27 +167,35 @@ class SyncRestServlet(RestServlet):
)
time_now = self.clock.time_msec()
+ response_content = self.encode_response(
+ time_now, sync_result, requester.access_token_id, filter
+ )
+
+ defer.returnValue((200, response_content))
- joined = self.encode_joined(
- sync_result.joined, time_now, requester.access_token_id, filter.event_fields
+ @staticmethod
+ def encode_response(time_now, sync_result, access_token_id, filter):
+ joined = SyncRestServlet.encode_joined(
+ sync_result.joined, time_now, access_token_id, filter.event_fields
)
- invited = self.encode_invited(
- sync_result.invited, time_now, requester.access_token_id
+ invited = SyncRestServlet.encode_invited(
+ sync_result.invited, time_now, access_token_id,
)
- archived = self.encode_archived(
- sync_result.archived, time_now, requester.access_token_id,
+ archived = SyncRestServlet.encode_archived(
+ sync_result.archived, time_now, access_token_id,
filter.event_fields,
)
- response_content = {
+ return {
"account_data": {"events": sync_result.account_data},
"to_device": {"events": sync_result.to_device},
"device_lists": {
- "changed": list(sync_result.device_lists),
+ "changed": list(sync_result.device_lists.changed),
+ "left": list(sync_result.device_lists.left),
},
- "presence": self.encode_presence(
+ "presence": SyncRestServlet.encode_presence(
sync_result.presence, time_now
),
"rooms": {
@@ -192,13 +203,17 @@ class SyncRestServlet(RestServlet):
"invite": invited,
"leave": archived,
},
+ "groups": {
+ "join": sync_result.groups.join,
+ "invite": sync_result.groups.invite,
+ "leave": sync_result.groups.leave,
+ },
"device_one_time_keys_count": sync_result.device_one_time_keys_count,
"next_batch": sync_result.next_batch.to_string(),
}
- defer.returnValue((200, response_content))
-
- def encode_presence(self, events, time_now):
+ @staticmethod
+ def encode_presence(events, time_now):
return {
"events": [
{
@@ -212,7 +227,8 @@ class SyncRestServlet(RestServlet):
]
}
- def encode_joined(self, rooms, time_now, token_id, event_fields):
+ @staticmethod
+ def encode_joined(rooms, time_now, token_id, event_fields):
"""
Encode the joined rooms in a sync result
@@ -231,13 +247,14 @@ class SyncRestServlet(RestServlet):
"""
joined = {}
for room in rooms:
- joined[room.room_id] = self.encode_room(
+ joined[room.room_id] = SyncRestServlet.encode_room(
room, time_now, token_id, only_fields=event_fields
)
return joined
- def encode_invited(self, rooms, time_now, token_id):
+ @staticmethod
+ def encode_invited(rooms, time_now, token_id):
"""
Encode the invited rooms in a sync result
@@ -270,7 +287,8 @@ class SyncRestServlet(RestServlet):
return invited
- def encode_archived(self, rooms, time_now, token_id, event_fields):
+ @staticmethod
+ def encode_archived(rooms, time_now, token_id, event_fields):
"""
Encode the archived rooms in a sync result
@@ -289,7 +307,7 @@ class SyncRestServlet(RestServlet):
"""
joined = {}
for room in rooms:
- joined[room.room_id] = self.encode_room(
+ joined[room.room_id] = SyncRestServlet.encode_room(
room, time_now, token_id, joined=False, only_fields=event_fields
)
diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py
index dac8603b07..4fea614e95 100644
--- a/synapse/rest/client/v2_alpha/tags.py
+++ b/synapse/rest/client/v2_alpha/tags.py
@@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import client_v2_patterns
-
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from synapse.api.errors import AuthError
+import logging
from twisted.internet import defer
-import logging
+from synapse.api.errors import AuthError
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+
+from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py
index 6fceb23e26..d9d379182e 100644
--- a/synapse/rest/client/v2_alpha/thirdparty.py
+++ b/synapse/rest/client/v2_alpha/thirdparty.py
@@ -20,13 +20,14 @@ from twisted.internet import defer
from synapse.api.constants import ThirdPartyEntityKind
from synapse.http.servlet import RestServlet
+
from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
class ThirdPartyProtocolsServlet(RestServlet):
- PATTERNS = client_v2_patterns("/thirdparty/protocols", releases=())
+ PATTERNS = client_v2_patterns("/thirdparty/protocols")
def __init__(self, hs):
super(ThirdPartyProtocolsServlet, self).__init__()
@@ -43,8 +44,7 @@ class ThirdPartyProtocolsServlet(RestServlet):
class ThirdPartyProtocolServlet(RestServlet):
- PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$",
- releases=())
+ PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$")
def __init__(self, hs):
super(ThirdPartyProtocolServlet, self).__init__()
@@ -66,8 +66,7 @@ class ThirdPartyProtocolServlet(RestServlet):
class ThirdPartyUserServlet(RestServlet):
- PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$",
- releases=())
+ PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$")
def __init__(self, hs):
super(ThirdPartyUserServlet, self).__init__()
@@ -90,8 +89,7 @@ class ThirdPartyUserServlet(RestServlet):
class ThirdPartyLocationServlet(RestServlet):
- PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$",
- releases=())
+ PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$")
def __init__(self, hs):
super(ThirdPartyLocationServlet, self).__init__()
diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py
index 6e012da4aa..cac0624ba7 100644
--- a/synapse/rest/client/v2_alpha/user_directory.py
+++ b/synapse/rest/client/v2_alpha/user_directory.py
@@ -19,6 +19,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+
from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
@@ -65,7 +66,7 @@ class UserDirectorySearchRestServlet(RestServlet):
try:
search_term = body["search_term"]
- except:
+ except Exception:
raise SynapseError(400, "`search_term` is required field")
results = yield self.user_directory_handler.search_users(
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index e984ea47db..6ac2987b98 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -13,11 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.http.servlet import RestServlet
-
import logging
import re
+from synapse.http.servlet import RestServlet
+
logger = logging.getLogger(__name__)
@@ -30,6 +30,7 @@ class VersionsRestServlet(RestServlet):
"r0.0.1",
"r0.1.0",
"r0.2.0",
+ "r0.3.0",
]
})
|