diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index 465b25033d..1197158fdc 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/client/v1/admin.py
@@ -137,7 +137,7 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)")
def __init__(self, hs):
- self.store = hs.get_datastore()
+ self._auth_handler = hs.get_auth_handler()
super(DeactivateAccountRestServlet, self).__init__(hs)
@defer.inlineCallbacks
@@ -149,12 +149,7 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
if not is_admin:
raise AuthError(403, "You are not a server admin")
- # FIXME: Theoretically there is a race here wherein user resets password
- # using threepid.
- yield self.store.user_delete_access_tokens(target_user_id)
- yield self.store.user_delete_threepids(target_user_id)
- yield self.store.user_set_password_hash(target_user_id, None)
-
+ yield self._auth_handler.deactivate_account(target_user_id)
defer.returnValue((200, {}))
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 9536e8ade6..5669ecb724 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -85,7 +85,6 @@ def login_id_thirdparty_from_phone(identifier):
class LoginRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/login$")
- PASS_TYPE = "m.login.password"
SAML2_TYPE = "m.login.saml2"
CAS_TYPE = "m.login.cas"
TOKEN_TYPE = "m.login.token"
@@ -94,7 +93,6 @@ class LoginRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(LoginRestServlet, self).__init__(hs)
self.idp_redirect_url = hs.config.saml2_idp_redirect_url
- self.password_enabled = hs.config.password_enabled
self.saml2_enabled = hs.config.saml2_enabled
self.jwt_enabled = hs.config.jwt_enabled
self.jwt_secret = hs.config.jwt_secret
@@ -121,8 +119,10 @@ class LoginRestServlet(ClientV1RestServlet):
# fall back to the fallback API if they don't understand one of the
# login flow types returned.
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
- if self.password_enabled:
- flows.append({"type": LoginRestServlet.PASS_TYPE})
+
+ flows.extend((
+ {"type": t} for t in self.auth_handler.get_supported_login_types()
+ ))
return (200, {"flows": flows})
@@ -133,14 +133,8 @@ class LoginRestServlet(ClientV1RestServlet):
def on_POST(self, request):
login_submission = parse_json_object_from_request(request)
try:
- if login_submission["type"] == LoginRestServlet.PASS_TYPE:
- if not self.password_enabled:
- raise SynapseError(400, "Password login has been disabled.")
-
- result = yield self.do_password_login(login_submission)
- defer.returnValue(result)
- elif self.saml2_enabled and (login_submission["type"] ==
- LoginRestServlet.SAML2_TYPE):
+ if self.saml2_enabled and (login_submission["type"] ==
+ LoginRestServlet.SAML2_TYPE):
relay_state = ""
if "relay_state" in login_submission:
relay_state = "&RelayState=" + urllib.quote(
@@ -157,15 +151,31 @@ class LoginRestServlet(ClientV1RestServlet):
result = yield self.do_token_login(login_submission)
defer.returnValue(result)
else:
- raise SynapseError(400, "Bad login type.")
+ result = yield self._do_other_login(login_submission)
+ defer.returnValue(result)
except KeyError:
raise SynapseError(400, "Missing JSON keys.")
@defer.inlineCallbacks
- def do_password_login(self, login_submission):
- if "password" not in login_submission:
- raise SynapseError(400, "Missing parameter: password")
+ def _do_other_login(self, login_submission):
+ """Handle non-token/saml/jwt logins
+ Args:
+ login_submission:
+
+ Returns:
+ (int, object): HTTP code/response
+ """
+ # Log the request we got, but only certain fields to minimise the chance of
+ # logging someone's password (even if they accidentally put it in the wrong
+ # field)
+ logger.info(
+ "Got login request with identifier: %r, medium: %r, address: %r, user: %r",
+ login_submission.get('identifier'),
+ login_submission.get('medium'),
+ login_submission.get('address'),
+ login_submission.get('user'),
+ )
login_submission_legacy_convert(login_submission)
if "identifier" not in login_submission:
@@ -208,30 +218,29 @@ class LoginRestServlet(ClientV1RestServlet):
if "user" not in identifier:
raise SynapseError(400, "User identifier is missing 'user' key")
- user_id = identifier["user"]
-
- if not user_id.startswith('@'):
- user_id = UserID(
- user_id, self.hs.hostname
- ).to_string()
-
auth_handler = self.auth_handler
- user_id = yield auth_handler.validate_password_login(
- user_id=user_id,
- password=login_submission["password"],
+ canonical_user_id, callback = yield auth_handler.validate_login(
+ identifier["user"],
+ login_submission,
+ )
+
+ device_id = yield self._register_device(
+ canonical_user_id, login_submission,
)
- device_id = yield self._register_device(user_id, login_submission)
access_token = yield auth_handler.get_access_token_for_user_id(
- user_id, device_id,
- login_submission.get("initial_device_display_name"),
+ canonical_user_id, device_id,
)
+
result = {
- "user_id": user_id, # may have changed
+ "user_id": canonical_user_id,
"access_token": access_token,
"home_server": self.hs.hostname,
"device_id": device_id,
}
+ if callback is not None:
+ yield callback(result)
+
defer.returnValue((200, result))
@defer.inlineCallbacks
@@ -244,7 +253,6 @@ class LoginRestServlet(ClientV1RestServlet):
device_id = yield self._register_device(user_id, login_submission)
access_token = yield auth_handler.get_access_token_for_user_id(
user_id, device_id,
- login_submission.get("initial_device_display_name"),
)
result = {
"user_id": user_id, # may have changed
@@ -287,7 +295,6 @@ class LoginRestServlet(ClientV1RestServlet):
)
access_token = yield auth_handler.get_access_token_for_user_id(
registered_user_id, device_id,
- login_submission.get("initial_device_display_name"),
)
result = {
diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py
index 1358d0acab..6add754782 100644
--- a/synapse/rest/client/v1/logout.py
+++ b/synapse/rest/client/v1/logout.py
@@ -30,7 +30,7 @@ class LogoutRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(LogoutRestServlet, self).__init__(hs)
- self.store = hs.get_datastore()
+ self._auth_handler = hs.get_auth_handler()
def on_OPTIONS(self, request):
return (200, {})
@@ -38,7 +38,7 @@ class LogoutRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
access_token = get_access_token_from_request(request)
- yield self.store.delete_access_token(access_token)
+ yield self._auth_handler.delete_access_token(access_token)
defer.returnValue((200, {}))
@@ -47,8 +47,8 @@ class LogoutAllRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(LogoutAllRestServlet, self).__init__(hs)
- self.store = hs.get_datastore()
self.auth = hs.get_auth()
+ self._auth_handler = hs.get_auth_handler()
def on_OPTIONS(self, request):
return (200, {})
@@ -57,7 +57,7 @@ class LogoutAllRestServlet(ClientV1RestServlet):
def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
- yield self.store.user_delete_access_tokens(user_id)
+ yield self._auth_handler.delete_access_tokens_for_user(user_id)
defer.returnValue((200, {}))
diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py
index ecf7e311a9..32ed1d3ab2 100644
--- a/synapse/rest/client/v1/register.py
+++ b/synapse/rest/client/v1/register.py
@@ -359,7 +359,7 @@ class RegisterRestServlet(ClientV1RestServlet):
if compare_digest(want_mac, got_mac):
handler = self.handlers.registration_handler
user_id, token = yield handler.register(
- localpart=user,
+ localpart=user.lower(),
password=password,
admin=bool(admin),
)
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 4990b22b9f..726e0a2826 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -13,22 +13,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
from twisted.internet import defer
+from synapse.api.auth import has_access_token
from synapse.api.constants import LoginType
-from synapse.api.errors import LoginError, SynapseError, Codes
+from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.http.servlet import (
- RestServlet, parse_json_object_from_request, assert_params_in_request
+ RestServlet, assert_params_in_request,
+ parse_json_object_from_request,
)
from synapse.util.async import run_on_reactor
from synapse.util.msisdn import phone_number_to_msisdn
-
from ._base import client_v2_patterns
-import logging
-
-
logger = logging.getLogger(__name__)
@@ -163,7 +162,6 @@ class DeactivateAccountRestServlet(RestServlet):
def __init__(self, hs):
self.hs = hs
- self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
super(DeactivateAccountRestServlet, self).__init__()
@@ -172,6 +170,20 @@ class DeactivateAccountRestServlet(RestServlet):
def on_POST(self, request):
body = parse_json_object_from_request(request)
+ # if the caller provides an access token, it ought to be valid.
+ requester = None
+ if has_access_token(request):
+ requester = yield self.auth.get_user_by_req(
+ request,
+ ) # type: synapse.types.Requester
+
+ # allow ASes to dectivate their own users
+ if requester and requester.app_service:
+ yield self.auth_handler.deactivate_account(
+ requester.user.to_string()
+ )
+ defer.returnValue((200, {}))
+
authed, result, params, _ = yield self.auth_handler.check_auth([
[LoginType.PASSWORD],
], body, self.hs.get_ip_from_request(request))
@@ -179,25 +191,22 @@ class DeactivateAccountRestServlet(RestServlet):
if not authed:
defer.returnValue((401, result))
- user_id = None
- requester = None
-
if LoginType.PASSWORD in result:
+ user_id = result[LoginType.PASSWORD]
# if using password, they should also be logged in
- requester = yield self.auth.get_user_by_req(request)
- user_id = requester.user.to_string()
- if user_id != result[LoginType.PASSWORD]:
+ if requester is None:
+ raise SynapseError(
+ 400,
+ "Deactivate account requires an access_token",
+ errcode=Codes.MISSING_TOKEN
+ )
+ if requester.user.to_string() != user_id:
raise LoginError(400, "", Codes.UNKNOWN)
else:
logger.error("Auth succeeded but no known type!", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN)
- # FIXME: Theoretically there is a race here wherein user resets password
- # using threepid.
- yield self.store.user_delete_access_tokens(user_id)
- yield self.store.user_delete_threepids(user_id)
- yield self.store.user_set_password_hash(user_id, None)
-
+ yield self.auth_handler.deactivate_account(user_id)
defer.returnValue((200, {}))
@@ -373,6 +382,20 @@ class ThreepidDeleteRestServlet(RestServlet):
defer.returnValue((200, {}))
+class WhoamiRestServlet(RestServlet):
+ PATTERNS = client_v2_patterns("/account/whoami$")
+
+ def __init__(self, hs):
+ super(WhoamiRestServlet, self).__init__()
+ self.auth = hs.get_auth()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request):
+ requester = yield self.auth.get_user_by_req(request)
+
+ defer.returnValue((200, {'user_id': requester.user.to_string()}))
+
+
def register_servlets(hs, http_server):
EmailPasswordRequestTokenRestServlet(hs).register(http_server)
MsisdnPasswordRequestTokenRestServlet(hs).register(http_server)
@@ -382,3 +405,4 @@ def register_servlets(hs, http_server):
MsisdnThreepidRequestTokenRestServlet(hs).register(http_server)
ThreepidRestServlet(hs).register(http_server)
ThreepidDeleteRestServlet(hs).register(http_server)
+ WhoamiRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index b57ba95d24..5321e5abbb 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
class DevicesRestServlet(servlet.RestServlet):
- PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False)
+ PATTERNS = client_v2_patterns("/devices$", v2_alpha=False)
def __init__(self, hs):
"""
@@ -51,7 +51,7 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
API for bulk deletion of devices. Accepts a JSON object with a devices
key which lists the device_ids to delete. Requires user interactive auth.
"""
- PATTERNS = client_v2_patterns("/delete_devices", releases=[], v2_alpha=False)
+ PATTERNS = client_v2_patterns("/delete_devices", v2_alpha=False)
def __init__(self, hs):
super(DeleteDevicesRestServlet, self).__init__()
@@ -93,8 +93,7 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
class DeviceRestServlet(servlet.RestServlet):
- PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
- releases=[], v2_alpha=False)
+ PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", v2_alpha=False)
def __init__(self, hs):
"""
@@ -118,6 +117,8 @@ class DeviceRestServlet(servlet.RestServlet):
@defer.inlineCallbacks
def on_DELETE(self, request, device_id):
+ requester = yield self.auth.get_user_by_req(request)
+
try:
body = servlet.parse_json_object_from_request(request)
@@ -136,11 +137,12 @@ class DeviceRestServlet(servlet.RestServlet):
if not authed:
defer.returnValue((401, result))
- requester = yield self.auth.get_user_by_req(request)
- yield self.device_handler.delete_device(
- requester.user.to_string(),
- device_id,
- )
+ # check that the UI auth matched the access token
+ user_id = result[constants.LoginType.PASSWORD]
+ if user_id != requester.user.to_string():
+ raise errors.AuthError(403, "Invalid auth")
+
+ yield self.device_handler.delete_device(user_id, device_id)
defer.returnValue((200, {}))
@defer.inlineCallbacks
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index 100f47ca9e..089ec71c81 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -39,20 +39,23 @@ class GroupServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, group_id):
requester = yield self.auth.get_user_by_req(request)
- user_id = requester.user.to_string()
+ requester_user_id = requester.user.to_string()
- group_description = yield self.groups_handler.get_group_profile(group_id, user_id)
+ group_description = yield self.groups_handler.get_group_profile(
+ group_id,
+ requester_user_id,
+ )
defer.returnValue((200, group_description))
@defer.inlineCallbacks
def on_POST(self, request, group_id):
requester = yield self.auth.get_user_by_req(request)
- user_id = requester.user.to_string()
+ requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
yield self.groups_handler.update_group_profile(
- group_id, user_id, content,
+ group_id, requester_user_id, content,
)
defer.returnValue((200, {}))
@@ -72,9 +75,12 @@ class GroupSummaryServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, group_id):
requester = yield self.auth.get_user_by_req(request)
- user_id = requester.user.to_string()
+ requester_user_id = requester.user.to_string()
- get_group_summary = yield self.groups_handler.get_group_summary(group_id, user_id)
+ get_group_summary = yield self.groups_handler.get_group_summary(
+ group_id,
+ requester_user_id,
+ )
defer.returnValue((200, get_group_summary))
@@ -101,11 +107,11 @@ class GroupSummaryRoomsCatServlet(RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, group_id, category_id, room_id):
requester = yield self.auth.get_user_by_req(request)
- user_id = requester.user.to_string()
+ requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
resp = yield self.groups_handler.update_group_summary_room(
- group_id, user_id,
+ group_id, requester_user_id,
room_id=room_id,
category_id=category_id,
content=content,
@@ -116,10 +122,10 @@ class GroupSummaryRoomsCatServlet(RestServlet):
@defer.inlineCallbacks
def on_DELETE(self, request, group_id, category_id, room_id):
requester = yield self.auth.get_user_by_req(request)
- user_id = requester.user.to_string()
+ requester_user_id = requester.user.to_string()
resp = yield self.groups_handler.delete_group_summary_room(
- group_id, user_id,
+ group_id, requester_user_id,
room_id=room_id,
category_id=category_id,
)
@@ -143,10 +149,10 @@ class GroupCategoryServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, group_id, category_id):
requester = yield self.auth.get_user_by_req(request)
- user_id = requester.user.to_string()
+ requester_user_id = requester.user.to_string()
category = yield self.groups_handler.get_group_category(
- group_id, user_id,
+ group_id, requester_user_id,
category_id=category_id,
)
@@ -155,11 +161,11 @@ class GroupCategoryServlet(RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, group_id, category_id):
requester = yield self.auth.get_user_by_req(request)
- user_id = requester.user.to_string()
+ requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
resp = yield self.groups_handler.update_group_category(
- group_id, user_id,
+ group_id, requester_user_id,
category_id=category_id,
content=content,
)
@@ -169,10 +175,10 @@ class GroupCategoryServlet(RestServlet):
@defer.inlineCallbacks
def on_DELETE(self, request, group_id, category_id):
requester = yield self.auth.get_user_by_req(request)
- user_id = requester.user.to_string()
+ requester_user_id = requester.user.to_string()
resp = yield self.groups_handler.delete_group_category(
- group_id, user_id,
+ group_id, requester_user_id,
category_id=category_id,
)
@@ -195,10 +201,10 @@ class GroupCategoriesServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, group_id):
requester = yield self.auth.get_user_by_req(request)
- user_id = requester.user.to_string()
+ requester_user_id = requester.user.to_string()
category = yield self.groups_handler.get_group_categories(
- group_id, user_id,
+ group_id, requester_user_id,
)
defer.returnValue((200, category))
@@ -220,10 +226,10 @@ class GroupRoleServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, group_id, role_id):
requester = yield self.auth.get_user_by_req(request)
- user_id = requester.user.to_string()
+ requester_user_id = requester.user.to_string()
category = yield self.groups_handler.get_group_role(
- group_id, user_id,
+ group_id, requester_user_id,
role_id=role_id,
)
@@ -232,11 +238,11 @@ class GroupRoleServlet(RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, group_id, role_id):
requester = yield self.auth.get_user_by_req(request)
- user_id = requester.user.to_string()
+ requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
resp = yield self.groups_handler.update_group_role(
- group_id, user_id,
+ group_id, requester_user_id,
role_id=role_id,
content=content,
)
@@ -246,10 +252,10 @@ class GroupRoleServlet(RestServlet):
@defer.inlineCallbacks
def on_DELETE(self, request, group_id, role_id):
requester = yield self.auth.get_user_by_req(request)
- user_id = requester.user.to_string()
+ requester_user_id = requester.user.to_string()
resp = yield self.groups_handler.delete_group_role(
- group_id, user_id,
+ group_id, requester_user_id,
role_id=role_id,
)
@@ -272,10 +278,10 @@ class GroupRolesServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, group_id):
requester = yield self.auth.get_user_by_req(request)
- user_id = requester.user.to_string()
+ requester_user_id = requester.user.to_string()
category = yield self.groups_handler.get_group_roles(
- group_id, user_id,
+ group_id, requester_user_id,
)
defer.returnValue((200, category))
@@ -343,9 +349,9 @@ class GroupRoomServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, group_id):
requester = yield self.auth.get_user_by_req(request)
- user_id = requester.user.to_string()
+ requester_user_id = requester.user.to_string()
- result = yield self.groups_handler.get_rooms_in_group(group_id, user_id)
+ result = yield self.groups_handler.get_rooms_in_group(group_id, requester_user_id)
defer.returnValue((200, result))
@@ -364,9 +370,9 @@ class GroupUsersServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, group_id):
requester = yield self.auth.get_user_by_req(request)
- user_id = requester.user.to_string()
+ requester_user_id = requester.user.to_string()
- result = yield self.groups_handler.get_users_in_group(group_id, user_id)
+ result = yield self.groups_handler.get_users_in_group(group_id, requester_user_id)
defer.returnValue((200, result))
@@ -385,9 +391,12 @@ class GroupInvitedUsersServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, group_id):
requester = yield self.auth.get_user_by_req(request)
- user_id = requester.user.to_string()
+ requester_user_id = requester.user.to_string()
- result = yield self.groups_handler.get_invited_users_in_group(group_id, user_id)
+ result = yield self.groups_handler.get_invited_users_in_group(
+ group_id,
+ requester_user_id,
+ )
defer.returnValue((200, result))
@@ -407,14 +416,18 @@ class GroupCreateServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request)
- user_id = requester.user.to_string()
+ requester_user_id = requester.user.to_string()
# TODO: Create group on remote server
content = parse_json_object_from_request(request)
localpart = content.pop("localpart")
group_id = GroupID(localpart, self.server_name).to_string()
- result = yield self.groups_handler.create_group(group_id, user_id, content)
+ result = yield self.groups_handler.create_group(
+ group_id,
+ requester_user_id,
+ content,
+ )
defer.returnValue((200, result))
@@ -435,11 +448,11 @@ class GroupAdminRoomsServlet(RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, group_id, room_id):
requester = yield self.auth.get_user_by_req(request)
- user_id = requester.user.to_string()
+ requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
result = yield self.groups_handler.add_room_to_group(
- group_id, user_id, room_id, content,
+ group_id, requester_user_id, room_id, content,
)
defer.returnValue((200, result))
@@ -447,10 +460,37 @@ class GroupAdminRoomsServlet(RestServlet):
@defer.inlineCallbacks
def on_DELETE(self, request, group_id, room_id):
requester = yield self.auth.get_user_by_req(request)
- user_id = requester.user.to_string()
+ requester_user_id = requester.user.to_string()
result = yield self.groups_handler.remove_room_from_group(
- group_id, user_id, room_id,
+ group_id, requester_user_id, room_id,
+ )
+
+ defer.returnValue((200, result))
+
+
+class GroupAdminRoomsConfigServlet(RestServlet):
+ """Update the config of a room in a group
+ """
+ PATTERNS = client_v2_patterns(
+ "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)"
+ "/config/(?P<config_key>[^/]*)$"
+ )
+
+ def __init__(self, hs):
+ super(GroupAdminRoomsConfigServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id, room_id, config_key):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+ result = yield self.groups_handler.update_room_in_group(
+ group_id, requester_user_id, room_id, config_key, content,
)
defer.returnValue((200, result))
@@ -685,9 +725,9 @@ class GroupsForUserServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request)
- user_id = requester.user.to_string()
+ requester_user_id = requester.user.to_string()
- result = yield self.groups_handler.get_joined_groups(user_id)
+ result = yield self.groups_handler.get_joined_groups(requester_user_id)
defer.returnValue((200, result))
@@ -700,6 +740,7 @@ def register_servlets(hs, http_server):
GroupRoomServlet(hs).register(http_server)
GroupCreateServlet(hs).register(http_server)
GroupAdminRoomsServlet(hs).register(http_server)
+ GroupAdminRoomsConfigServlet(hs).register(http_server)
GroupAdminUsersInviteServlet(hs).register(http_server)
GroupAdminUsersKickServlet(hs).register(http_server)
GroupSelfLeaveServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 943e87e7fd..3cc87ea63f 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -53,8 +53,7 @@ class KeyUploadServlet(RestServlet):
},
}
"""
- PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$",
- releases=())
+ PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
def __init__(self, hs):
"""
@@ -128,10 +127,7 @@ class KeyQueryServlet(RestServlet):
} } } } } }
"""
- PATTERNS = client_v2_patterns(
- "/keys/query$",
- releases=()
- )
+ PATTERNS = client_v2_patterns("/keys/query$")
def __init__(self, hs):
"""
@@ -160,10 +156,7 @@ class KeyChangesServlet(RestServlet):
200 OK
{ "changed": ["@foo:example.com"] }
"""
- PATTERNS = client_v2_patterns(
- "/keys/changes$",
- releases=()
- )
+ PATTERNS = client_v2_patterns("/keys/changes$")
def __init__(self, hs):
"""
@@ -213,10 +206,7 @@ class OneTimeKeyServlet(RestServlet):
} } } }
"""
- PATTERNS = client_v2_patterns(
- "/keys/claim$",
- releases=()
- )
+ PATTERNS = client_v2_patterns("/keys/claim$")
def __init__(self, hs):
super(OneTimeKeyServlet, self).__init__()
diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py
index fd2a3d69d4..ec170109fe 100644
--- a/synapse/rest/client/v2_alpha/notifications.py
+++ b/synapse/rest/client/v2_alpha/notifications.py
@@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
class NotificationsServlet(RestServlet):
- PATTERNS = client_v2_patterns("/notifications$", releases=())
+ PATTERNS = client_v2_patterns("/notifications$")
def __init__(self, hs):
super(NotificationsServlet, self).__init__()
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index d9a8cdbbb5..9e2f7308ce 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -224,6 +224,12 @@ class RegisterRestServlet(RestServlet):
# 'user' key not 'username'). Since this is a new addition, we'll
# fallback to 'username' if they gave one.
desired_username = body.get("user", desired_username)
+
+ # XXX we should check that desired_username is valid. Currently
+ # we give appservices carte blanche for any insanity in mxids,
+ # because the IRC bridges rely on being able to register stupid
+ # IDs.
+
access_token = get_access_token_from_request(request)
if isinstance(desired_username, basestring):
@@ -233,6 +239,15 @@ class RegisterRestServlet(RestServlet):
defer.returnValue((200, result)) # we throw for non 200 responses
return
+ # for either shared secret or regular registration, downcase the
+ # provided username before attempting to register it. This should mean
+ # that people who try to register with upper-case in their usernames
+ # don't get a nasty surprise. (Note that we treat username
+ # case-insenstively in login, so they are free to carry on imagining
+ # that their username is CrAzYh4cKeR if that keeps them happy)
+ if desired_username is not None:
+ desired_username = desired_username.lower()
+
# == Shared Secret Registration == (e.g. create new user scripts)
if 'mac' in body:
# FIXME: Should we really be determining if this is shared secret
@@ -336,6 +351,9 @@ class RegisterRestServlet(RestServlet):
new_password = params.get("password", None)
guest_access_token = params.get("guest_access_token", None)
+ if desired_username is not None:
+ desired_username = desired_username.lower()
+
(registered_user_id, _) = yield self.registration_handler.register(
localpart=desired_username,
password=new_password,
@@ -417,13 +435,22 @@ class RegisterRestServlet(RestServlet):
def _do_shared_secret_registration(self, username, password, body):
if not self.hs.config.registration_shared_secret:
raise SynapseError(400, "Shared secret registration is not enabled")
+ if not username:
+ raise SynapseError(
+ 400, "username must be specified", errcode=Codes.BAD_JSON,
+ )
- user = username.encode("utf-8")
+ # use the username from the original request rather than the
+ # downcased one in `username` for the mac calculation
+ user = body["username"].encode("utf-8")
# str() because otherwise hmac complains that 'unicode' does not
# have the buffer interface
got_mac = str(body["mac"])
+ # FIXME this is different to the /v1/register endpoint, which
+ # includes the password and admin flag in the hashed text. Why are
+ # these different?
want_mac = hmac.new(
key=self.hs.config.registration_shared_secret,
msg=user,
@@ -557,25 +584,28 @@ class RegisterRestServlet(RestServlet):
Args:
(str) user_id: full canonical @user:id
(object) params: registration parameters, from which we pull
- device_id and initial_device_name
+ device_id, initial_device_name and inhibit_login
Returns:
defer.Deferred: (object) dictionary for response from /register
"""
- device_id = yield self._register_device(user_id, params)
+ result = {
+ "user_id": user_id,
+ "home_server": self.hs.hostname,
+ }
+ if not params.get("inhibit_login", False):
+ device_id = yield self._register_device(user_id, params)
- access_token = (
- yield self.auth_handler.get_access_token_for_user_id(
- user_id, device_id=device_id,
- initial_display_name=params.get("initial_device_display_name")
+ access_token = (
+ yield self.auth_handler.get_access_token_for_user_id(
+ user_id, device_id=device_id,
+ )
)
- )
- defer.returnValue({
- "user_id": user_id,
- "access_token": access_token,
- "home_server": self.hs.hostname,
- "device_id": device_id,
- })
+ result.update({
+ "access_token": access_token,
+ "device_id": device_id,
+ })
+ defer.returnValue(result)
def _register_device(self, user_id, params):
"""Register a device for a user.
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index d607bd2970..90bdb1db15 100644
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
class SendToDeviceRestServlet(servlet.RestServlet):
PATTERNS = client_v2_patterns(
"/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$",
- releases=[], v2_alpha=False
+ v2_alpha=False
)
def __init__(self, hs):
diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py
index 6fceb23e26..6773b9ba60 100644
--- a/synapse/rest/client/v2_alpha/thirdparty.py
+++ b/synapse/rest/client/v2_alpha/thirdparty.py
@@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
class ThirdPartyProtocolsServlet(RestServlet):
- PATTERNS = client_v2_patterns("/thirdparty/protocols", releases=())
+ PATTERNS = client_v2_patterns("/thirdparty/protocols")
def __init__(self, hs):
super(ThirdPartyProtocolsServlet, self).__init__()
@@ -43,8 +43,7 @@ class ThirdPartyProtocolsServlet(RestServlet):
class ThirdPartyProtocolServlet(RestServlet):
- PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$",
- releases=())
+ PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$")
def __init__(self, hs):
super(ThirdPartyProtocolServlet, self).__init__()
@@ -66,8 +65,7 @@ class ThirdPartyProtocolServlet(RestServlet):
class ThirdPartyUserServlet(RestServlet):
- PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$",
- releases=())
+ PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$")
def __init__(self, hs):
super(ThirdPartyUserServlet, self).__init__()
@@ -90,8 +88,7 @@ class ThirdPartyUserServlet(RestServlet):
class ThirdPartyLocationServlet(RestServlet):
- PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$",
- releases=())
+ PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$")
def __init__(self, hs):
super(ThirdPartyLocationServlet, self).__init__()
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 80114fca0d..723f7043f4 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -20,6 +20,7 @@ from twisted.web.resource import Resource
from synapse.api.errors import (
SynapseError, Codes,
)
+from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
from synapse.util.stringutils import random_string
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.http.client import SpiderHttpClient
@@ -63,16 +64,15 @@ class PreviewUrlResource(Resource):
self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
- # simple memory cache mapping urls to OG metadata
- self.cache = ExpiringCache(
+ # memory cache mapping urls to an ObservableDeferred returning
+ # JSON-encoded OG metadata
+ self._cache = ExpiringCache(
cache_name="url_previews",
clock=self.clock,
# don't spider URLs more often than once an hour
expiry_ms=60 * 60 * 1000,
)
- self.cache.start()
-
- self.downloads = {}
+ self._cache.start()
self._cleaner_loop = self.clock.looping_call(
self._expire_url_cache_data, 10 * 1000
@@ -94,6 +94,7 @@ class PreviewUrlResource(Resource):
else:
ts = self.clock.time_msec()
+ # XXX: we could move this into _do_preview if we wanted.
url_tuple = urlparse.urlsplit(url)
for entry in self.url_preview_url_blacklist:
match = True
@@ -126,14 +127,42 @@ class PreviewUrlResource(Resource):
Codes.UNKNOWN
)
- # first check the memory cache - good to handle all the clients on this
- # HS thundering away to preview the same URL at the same time.
- og = self.cache.get(url)
- if og:
- respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
- return
+ # the in-memory cache:
+ # * ensures that only one request is active at a time
+ # * takes load off the DB for the thundering herds
+ # * also caches any failures (unlike the DB) so we don't keep
+ # requesting the same endpoint
+
+ observable = self._cache.get(url)
+
+ if not observable:
+ download = preserve_fn(self._do_preview)(
+ url, requester.user, ts,
+ )
+ observable = ObservableDeferred(
+ download,
+ consumeErrors=True
+ )
+ self._cache[url] = observable
+ else:
+ logger.info("Returning cached response")
+
+ og = yield make_deferred_yieldable(observable.observe())
+ respond_with_json_bytes(request, 200, og, send_cors=True)
- # then check the URL cache in the DB (which will also provide us with
+ @defer.inlineCallbacks
+ def _do_preview(self, url, user, ts):
+ """Check the db, and download the URL and build a preview
+
+ Args:
+ url (str):
+ user (str):
+ ts (int):
+
+ Returns:
+ Deferred[str]: json-encoded og data
+ """
+ # check the URL cache in the DB (which will also provide us with
# historical previews, if we have any)
cache_result = yield self.store.get_url_cache(url, ts)
if (
@@ -141,32 +170,10 @@ class PreviewUrlResource(Resource):
cache_result["expires_ts"] > ts and
cache_result["response_code"] / 100 == 2
):
- respond_with_json_bytes(
- request, 200, cache_result["og"].encode('utf-8'),
- send_cors=True
- )
+ defer.returnValue(cache_result["og"])
return
- # Ensure only one download for a given URL is active at a time
- download = self.downloads.get(url)
- if download is None:
- download = self._download_url(url, requester.user)
- download = ObservableDeferred(
- download,
- consumeErrors=True
- )
- self.downloads[url] = download
-
- @download.addBoth
- def callback(media_info):
- del self.downloads[url]
- return media_info
- media_info = yield download.observe()
-
- # FIXME: we should probably update our cache now anyway, so that
- # even if the OG calculation raises, we don't keep hammering on the
- # remote server. For now, leave it uncached to aid debugging OG
- # calculation problems
+ media_info = yield self._download_url(url, user)
logger.debug("got media_info of '%s'" % media_info)
@@ -212,7 +219,7 @@ class PreviewUrlResource(Resource):
# just rely on the caching on the master request to speed things up.
if 'og:image' in og and og['og:image']:
image_info = yield self._download_url(
- _rebase_url(og['og:image'], media_info['uri']), requester.user
+ _rebase_url(og['og:image'], media_info['uri']), user
)
if _is_media(image_info['media_type']):
@@ -239,8 +246,7 @@ class PreviewUrlResource(Resource):
logger.debug("Calculated OG for %s as %s" % (url, og))
- # store OG in ephemeral in-memory cache
- self.cache[url] = og
+ jsonog = json.dumps(og)
# store OG in history-aware DB cache
yield self.store.store_url_cache(
@@ -248,12 +254,12 @@ class PreviewUrlResource(Resource):
media_info["response_code"],
media_info["etag"],
media_info["expires"] + media_info["created_ts"],
- json.dumps(og),
+ jsonog,
media_info["filesystem_id"],
media_info["created_ts"],
)
- respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
+ defer.returnValue(jsonog)
@defer.inlineCallbacks
def _download_url(self, url, user):
@@ -520,7 +526,14 @@ def _calc_og(tree, media_uri):
from lxml import etree
TAGS_TO_REMOVE = (
- "header", "nav", "aside", "footer", "script", "style", etree.Comment
+ "header",
+ "nav",
+ "aside",
+ "footer",
+ "script",
+ "noscript",
+ "style",
+ etree.Comment
)
# Split all the text nodes into paragraphs (by splitting on new
|