diff options
Diffstat (limited to 'synapse/rest')
-rw-r--r-- | synapse/rest/base.py | 2 | ||||
-rw-r--r-- | synapse/rest/register.py | 4 | ||||
-rw-r--r-- | synapse/rest/room.py | 217 | ||||
-rw-r--r-- | synapse/rest/transactions.py | 96 |
4 files changed, 226 insertions, 93 deletions
diff --git a/synapse/rest/base.py b/synapse/rest/base.py index 6a88cbe866..e855d293e5 100644 --- a/synapse/rest/base.py +++ b/synapse/rest/base.py @@ -15,6 +15,7 @@ """ This module contains base REST classes for constructing REST servlets. """ from synapse.api.urls import CLIENT_PREFIX +from synapse.rest.transactions import HttpTransactionStore import re @@ -59,6 +60,7 @@ class RestServlet(object): self.handlers = hs.get_handlers() self.event_factory = hs.get_event_factory() self.auth = hs.get_auth() + self.txns = HttpTransactionStore() def register(self, http_server): """ Register this servlet with the given HTTP server. """ diff --git a/synapse/rest/register.py b/synapse/rest/register.py index eb457562b9..f17ec11cf4 100644 --- a/synapse/rest/register.py +++ b/synapse/rest/register.py @@ -33,10 +33,10 @@ class RegisterRestServlet(RestServlet): try: register_json = json.loads(request.content.read()) if "password" in register_json: - password = register_json["password"] + password = register_json["password"].encode("utf-8") if type(register_json["user_id"]) == unicode: - desired_user_id = register_json["user_id"] + desired_user_id = register_json["user_id"].encode("utf-8") if urllib.quote(desired_user_id) != desired_user_id: raise SynapseError( 400, diff --git a/synapse/rest/room.py b/synapse/rest/room.py index f5b547b963..6771da8fcd 100644 --- a/synapse/rest/room.py +++ b/synapse/rest/room.py @@ -18,9 +18,10 @@ from twisted.internet import defer from base import RestServlet, client_path_pattern from synapse.api.errors import SynapseError, Codes -from synapse.api.events.room import (RoomTopicEvent, MessageEvent, - RoomMemberEvent, FeedbackEvent) -from synapse.api.constants import Feedback, Membership +from synapse.api.events.room import ( + MessageEvent, RoomMemberEvent, FeedbackEvent +) +from synapse.api.constants import Feedback from synapse.api.streams import PaginationConfig import json @@ -95,46 +96,76 @@ class RoomCreateRestServlet(RestServlet): return (200, {}) -class RoomTopicRestServlet(RestServlet): - PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/topic$") +class RoomStateEventRestServlet(RestServlet): + def register(self, http_server): + # /room/$roomid/state/$eventtype + no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$" - def get_event_type(self): - return RoomTopicEvent.TYPE + # /room/$roomid/state/$eventtype/$statekey + state_key = ("/rooms/(?P<room_id>[^/]*)/state/" + + "(?P<event_type>[^/]*)/(?P<state_key>[^/]*)$") + + http_server.register_path("GET", + client_path_pattern(state_key), + self.on_GET) + http_server.register_path("PUT", + client_path_pattern(state_key), + self.on_PUT) + http_server.register_path("GET", + client_path_pattern(no_state_key), + self.on_GET_no_state_key) + http_server.register_path("PUT", + client_path_pattern(no_state_key), + self.on_PUT_no_state_key) + + def on_GET_no_state_key(self, request, room_id, event_type): + return self.on_GET(request, room_id, event_type, "") + + def on_PUT_no_state_key(self, request, room_id, event_type): + return self.on_PUT(request, room_id, event_type, "") @defer.inlineCallbacks - def on_GET(self, request, room_id): + def on_GET(self, request, room_id, event_type, state_key): user = yield self.auth.get_user_by_req(request) msg_handler = self.handlers.message_handler data = yield msg_handler.get_room_data( user_id=user.to_string(), room_id=urllib.unquote(room_id), - event_type=RoomTopicEvent.TYPE, - state_key="", + event_type=urllib.unquote(event_type), + state_key=urllib.unquote(state_key), ) if not data: - raise SynapseError(404, "Topic not found.", errcode=Codes.NOT_FOUND) - defer.returnValue((200, data.content)) + raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) + defer.returnValue((200, data[0].get_dict()["content"])) @defer.inlineCallbacks - def on_PUT(self, request, room_id): + def on_PUT(self, request, room_id, event_type, state_key): user = yield self.auth.get_user_by_req(request) + event_type = urllib.unquote(event_type) content = _parse_json(request) event = self.event_factory.create_event( - etype=self.get_event_type(), + etype=event_type, content=content, room_id=urllib.unquote(room_id), user_id=user.to_string(), + state_key=urllib.unquote(state_key) ) - - msg_handler = self.handlers.message_handler - yield msg_handler.store_room_data( - event=event - ) - defer.returnValue((200, "")) + if event_type == RoomMemberEvent.TYPE: + # membership events are special + handler = self.handlers.room_member_handler + yield handler.change_membership(event) + defer.returnValue((200, "")) + else: + # store random bits of state + msg_handler = self.handlers.message_handler + yield msg_handler.store_room_data( + event=event + ) + defer.returnValue((200, "")) class JoinRoomAliasServlet(RestServlet): @@ -157,73 +188,6 @@ class JoinRoomAliasServlet(RestServlet): defer.returnValue((200, ret_dict)) -class RoomMemberRestServlet(RestServlet): - PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/members/" - + "(?P<target_user_id>[^/]*)/state$") - - def get_event_type(self): - return RoomMemberEvent.TYPE - - @defer.inlineCallbacks - def on_GET(self, request, room_id, target_user_id): - room_id = urllib.unquote(room_id) - user = yield self.auth.get_user_by_req(request) - - handler = self.handlers.room_member_handler - member = yield handler.get_room_member( - room_id, - urllib.unquote(target_user_id), - user.to_string()) - if not member: - raise SynapseError(404, "Member not found.", - errcode=Codes.NOT_FOUND) - defer.returnValue((200, member.content)) - - @defer.inlineCallbacks - def on_DELETE(self, request, roomid, target_user_id): - user = yield self.auth.get_user_by_req(request) - - event = self.event_factory.create_event( - etype=self.get_event_type(), - target_user_id=urllib.unquote(target_user_id), - room_id=urllib.unquote(roomid), - user_id=user.to_string(), - membership=Membership.LEAVE, - content={"membership": Membership.LEAVE} - ) - - handler = self.handlers.room_member_handler - yield handler.change_membership(event) - defer.returnValue((200, "")) - - @defer.inlineCallbacks - def on_PUT(self, request, roomid, target_user_id): - user = yield self.auth.get_user_by_req(request) - - content = _parse_json(request) - if "membership" not in content: - raise SynapseError(400, "No membership key.", - errcode=Codes.BAD_JSON) - - valid_membership_values = [Membership.JOIN, Membership.INVITE] - if (content["membership"] not in valid_membership_values): - raise SynapseError(400, "Membership value must be %s." % ( - valid_membership_values,), errcode=Codes.BAD_JSON) - - event = self.event_factory.create_event( - etype=self.get_event_type(), - target_user_id=urllib.unquote(target_user_id), - room_id=urllib.unquote(roomid), - user_id=user.to_string(), - membership=content["membership"], - content=content - ) - - handler = self.handlers.room_member_handler - yield handler.change_membership(event) - defer.returnValue((200, "")) - - class MessageRestServlet(RestServlet): PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/messages/" + "(?P<sender_id>[^/]*)/(?P<msg_id>[^/]*)$") @@ -285,7 +249,7 @@ class FeedbackRestServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id, msg_sender_id, msg_id, fb_sender_id, feedback_type): - user = yield (self.auth.get_user_by_req(request)) + yield (self.auth.get_user_by_req(request)) # TODO (erikj): Implement this? raise NotImplementedError("Getting feedback is not supported") @@ -354,7 +318,8 @@ class RoomMemberListRestServlet(RestServlet): user_id=user.to_string()) for event in members["chunk"]: - target_user = self.hs.parse_userid(event["target_user_id"]) + # FIXME: should probably be state_key here, not user_id + target_user = self.hs.parse_userid(event["user_id"]) # Presence is an optional cache; don't fail if we can't fetch it try: presence_state = yield self.handlers.presence_handler.get_state( @@ -400,6 +365,52 @@ class RoomTriggerBackfill(RestServlet): res = [event.get_dict() for event in events] defer.returnValue((200, res)) + +class RoomMembershipRestServlet(RestServlet): + + def register(self, http_server): + # /rooms/$roomid/[invite|join|leave] + PATTERN = ("/rooms/(?P<room_id>[^/]*)/" + + "(?P<membership_action>join|invite|leave)") + register_txn_path(self, PATTERN, http_server) + + @defer.inlineCallbacks + def on_POST(self, request, room_id, membership_action): + user = yield self.auth.get_user_by_req(request) + + content = _parse_json(request) + + # target user is you unless it is an invite + state_key = user.to_string() + if membership_action == "invite": + if "user_id" not in content: + raise SynapseError(400, "Missing user_id key.") + state_key = content["user_id"] + + event = self.event_factory.create_event( + etype=RoomMemberEvent.TYPE, + content={"membership": unicode(membership_action)}, + room_id=urllib.unquote(room_id), + user_id=user.to_string(), + state_key=state_key + ) + handler = self.handlers.room_member_handler + yield handler.change_membership(event) + defer.returnValue((200, "")) + + @defer.inlineCallbacks + def on_PUT(self, request, room_id, membership_action, txn_id): + try: + defer.returnValue(self.txns.get_client_transaction(request, txn_id)) + except: + pass + + response = yield self.on_POST(request, room_id, membership_action) + + self.txns.store_client_transaction(request, txn_id, response) + defer.returnValue(response) + + def _parse_json(request): try: content = json.loads(request.content.read()) @@ -411,9 +422,32 @@ def _parse_json(request): raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) +def register_txn_path(servlet, regex_string, http_server): + """Registers a transaction-based path. + + This registers two paths: + PUT regex_string/$txnid + POST regex_string + + Args: + regex_string (str): The regex string to register. Must NOT have a + trailing $ as this string will be appended to. + http_server : The http_server to register paths with. + """ + http_server.register_path( + "POST", + client_path_pattern(regex_string + "$"), + servlet.on_POST + ) + http_server.register_path( + "PUT", + client_path_pattern(regex_string + "/(?P<txn_id>[^/]*)$"), + servlet.on_PUT + ) + + def register_servlets(hs, http_server): - RoomTopicRestServlet(hs).register(http_server) - RoomMemberRestServlet(hs).register(http_server) + RoomStateEventRestServlet(hs).register(http_server) MessageRestServlet(hs).register(http_server) FeedbackRestServlet(hs).register(http_server) RoomCreateRestServlet(hs).register(http_server) @@ -421,3 +455,4 @@ def register_servlets(hs, http_server): RoomMessageListRestServlet(hs).register(http_server) JoinRoomAliasServlet(hs).register(http_server) RoomTriggerBackfill(hs).register(http_server) + RoomMembershipRestServlet(hs).register(http_server) diff --git a/synapse/rest/transactions.py b/synapse/rest/transactions.py new file mode 100644 index 0000000000..b8aa1ef11c --- /dev/null +++ b/synapse/rest/transactions.py @@ -0,0 +1,96 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module contains logic for storing HTTP PUT transactions. This is used +to ensure idempotency when performing PUTs using the REST API.""" +import logging + +logger = logging.getLogger(__name__) + + +class HttpTransactionStore(object): + + def __init__(self): + # { key : (txn_id, response) } + self.transactions = {} + + def get_response(self, key, txn_id): + """Retrieve a response for this request. + + Args: + key (str): A transaction-independent key for this request. Typically + this is a combination of the path (without the transaction id) and + the user's access token. + txn_id (str): The transaction ID for this request + Returns: + A tuple of (HTTP response code, response content) or None. + """ + try: + logger.debug("get_response Key: %s TxnId: %s", key, txn_id) + (last_txn_id, response) = self.transactions[key] + if txn_id == last_txn_id: + logger.info("get_response: Returning a response for %s", key) + return response + except KeyError: + pass + return None + + def store_response(self, key, txn_id, response): + """Stores an HTTP response tuple. + + Args: + key (str): A transaction-independent key for this request. Typically + this is a combination of the path (without the transaction id) and + the user's access token. + txn_id (str): The transaction ID for this request. + response (tuple): A tuple of (HTTP response code, response content) + """ + logger.debug("store_response Key: %s TxnId: %s", key, txn_id) + self.transactions[key] = (txn_id, response) + + def store_client_transaction(self, request, txn_id, response): + """Stores the request/response pair of an HTTP transaction. + + Args: + request (twisted.web.http.Request): The twisted HTTP request. This + request must have the transaction ID as the last path segment. + response (tuple): A tuple of (response code, response dict) + txn_id (str): The transaction ID for this request. + """ + self.store_response(self._get_key(request), txn_id, response) + + def get_client_transaction(self, request, txn_id): + """Retrieves a stored response if there was one. + + Args: + request (twisted.web.http.Request): The twisted HTTP request. This + request must have the transaction ID as the last path segment. + txn_id (str): The transaction ID for this request. + Returns: + The response tuple. + Raises: + KeyError if the transaction was not found. + """ + response = self.get_response(self._get_key(request), txn_id) + if response is None: + raise KeyError("Transaction not found.") + return response + + def _get_key(self, request): + token = request.args["access_token"][0] + path_without_txn_id = request.path.rsplit("/", 1)[0] + return path_without_txn_id + "/" + token + + |