summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/rest/client/v1/base.py4
-rw-r--r--synapse/rest/client/v1/room.py82
-rw-r--r--synapse/rest/client/v1/transactions.py52
-rw-r--r--synapse/rest/client/v2_alpha/sendtodevice.py18
4 files changed, 68 insertions, 88 deletions
diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py
index c2a8447860..22c740c30c 100644
--- a/synapse/rest/client/v1/base.py
+++ b/synapse/rest/client/v1/base.py
@@ -18,7 +18,7 @@
 
 from synapse.http.servlet import RestServlet
 from synapse.api.urls import CLIENT_PREFIX
-from .transactions import HttpTransactionStore
+from .transactions import HttpTransactionCache
 import re
 
 import logging
@@ -59,4 +59,4 @@ class ClientV1RestServlet(RestServlet):
         self.hs = hs
         self.builder_factory = hs.get_event_builder_factory()
         self.auth = hs.get_v1auth()
-        self.txns = HttpTransactionStore()
+        self.txns = HttpTransactionCache()
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 010fbc7c32..2e919de9f3 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -56,15 +56,15 @@ class RoomCreateRestServlet(ClientV1RestServlet):
     @defer.inlineCallbacks
     def on_PUT(self, request, txn_id):
         try:
-            defer.returnValue(
-                self.txns.get_client_transaction(request, txn_id)
-            )
+            res_deferred = self.txns.get_client_transaction(request, txn_id)
+            res = yield res_deferred
+            defer.returnValue(res)
         except KeyError:
             pass
 
-        response = yield self.on_POST(request)
-
-        self.txns.store_client_transaction(request, txn_id, response)
+        res_deferred = self.on_POST(request)
+        self.txns.store_client_transaction(request, txn_id, res_deferred)
+        response = yield res_deferred
         defer.returnValue(response)
 
     @defer.inlineCallbacks
@@ -217,15 +217,15 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
     @defer.inlineCallbacks
     def on_PUT(self, request, room_id, event_type, txn_id):
         try:
-            defer.returnValue(
-                self.txns.get_client_transaction(request, txn_id)
-            )
+            res_deferred = self.txns.get_client_transaction(request, txn_id)
+            res = yield res_deferred
+            defer.returnValue(res)
         except KeyError:
             pass
-
-        response = yield self.on_POST(request, room_id, event_type, txn_id)
-
-        self.txns.store_client_transaction(request, txn_id, response)
+        
+        res_deferred = self.on_POST(request, room_id, event_type, txn_id)
+        self.txns.store_client_transaction(request, txn_id, res_deferred)
+        response = yield res_deferred
         defer.returnValue(response)
 
 
@@ -286,15 +286,15 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
     @defer.inlineCallbacks
     def on_PUT(self, request, room_identifier, txn_id):
         try:
-            defer.returnValue(
-                self.txns.get_client_transaction(request, txn_id)
-            )
+            res_deferred = self.txns.get_client_transaction(request, txn_id)
+            res = yield res_deferred
+            defer.returnValue(res)
         except KeyError:
             pass
-
-        response = yield self.on_POST(request, room_identifier, txn_id)
-
-        self.txns.store_client_transaction(request, txn_id, response)
+            
+        res_deferred = self.on_POST(request, room_identifier, txn_id)
+        self.txns.store_client_transaction(request, txn_id, res_deferred)
+        response = yield res_deferred
         defer.returnValue(response)
 
 
@@ -540,17 +540,15 @@ class RoomForgetRestServlet(ClientV1RestServlet):
     @defer.inlineCallbacks
     def on_PUT(self, request, room_id, txn_id):
         try:
-            defer.returnValue(
-                self.txns.get_client_transaction(request, txn_id)
-            )
+            res_deferred = self.txns.get_client_transaction(request, txn_id)
+            res = yield res_deferred
+            defer.returnValue(res)
         except KeyError:
             pass
-
-        response = yield self.on_POST(
-            request, room_id, txn_id
-        )
-
-        self.txns.store_client_transaction(request, txn_id, response)
+        
+        res_deferred = self.on_POST(request, room_id, txn_id)
+        self.txns.store_client_transaction(request, txn_id, res_deferred)
+        response = yield res_deferred
         defer.returnValue(response)
 
 
@@ -626,17 +624,15 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
     @defer.inlineCallbacks
     def on_PUT(self, request, room_id, membership_action, txn_id):
         try:
-            defer.returnValue(
-                self.txns.get_client_transaction(request, txn_id)
-            )
+            res_deferred = self.txns.get_client_transaction(request, txn_id)
+            res = yield res_deferred
+            defer.returnValue(res)
         except KeyError:
             pass
 
-        response = yield self.on_POST(
-            request, room_id, membership_action, txn_id
-        )
-
-        self.txns.store_client_transaction(request, txn_id, response)
+        res_deferred = self.on_POST(request, room_id, membership_action, txn_id)
+        self.txns.store_client_transaction(request, txn_id, res_deferred)
+        response = yield res_deferred
         defer.returnValue(response)
 
 
@@ -672,15 +668,15 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
     @defer.inlineCallbacks
     def on_PUT(self, request, room_id, event_id, txn_id):
         try:
-            defer.returnValue(
-                self.txns.get_client_transaction(request, txn_id)
-            )
+            res_deferred = self.txns.get_client_transaction(request, txn_id)
+            res = yield res_deferred
+            defer.returnValue(res)
         except KeyError:
             pass
 
-        response = yield self.on_POST(request, room_id, event_id, txn_id)
-
-        self.txns.store_client_transaction(request, txn_id, response)
+        res_deferred = self.on_POST(request, room_id, event_id, txn_id)
+        self.txns.store_client_transaction(request, txn_id, res_deferred)
+        response = yield res_deferred
         defer.returnValue(response)
 
 
diff --git a/synapse/rest/client/v1/transactions.py b/synapse/rest/client/v1/transactions.py
index 2f2c9d0881..f5012c5f59 100644
--- a/synapse/rest/client/v1/transactions.py
+++ b/synapse/rest/client/v1/transactions.py
@@ -22,57 +22,35 @@ from synapse.api.auth import get_access_token_from_request
 logger = logging.getLogger(__name__)
 
 
-# FIXME: elsewhere we use FooStore to indicate something in the storage layer...
-class HttpTransactionStore(object):
+class HttpTransactionCache(object):
 
     def __init__(self):
-        # { key : (txn_id, response) }
+        # { key : (txn_id, response_deferred) }
         self.transactions = {}
 
-    def get_response(self, key, txn_id):
-        """Retrieve a response for this request.
-
-        Args:
-            key (str): A transaction-independent key for this request. Usually
-                this is a combination of the path (without the transaction id)
-                and the user's access token.
-            txn_id (str): The transaction ID for this request
-        Returns:
-            A tuple of (HTTP response code, response content) or None.
-        """
+    def _get_response(self, key, txn_id):
         try:
-            logger.debug("get_response TxnId: %s", txn_id)
-            (last_txn_id, response) = self.transactions[key]
+            (last_txn_id, response_deferred) = self.transactions[key]
             if txn_id == last_txn_id:
                 logger.info("get_response: Returning a response for %s", txn_id)
-                return response
+                return response_deferred
         except KeyError:
             pass
         return None
 
-    def store_response(self, key, txn_id, response):
-        """Stores an HTTP response tuple.
-
-        Args:
-            key (str): A transaction-independent key for this request. Usually
-                this is a combination of the path (without the transaction id)
-                and the user's access token.
-            txn_id (str): The transaction ID for this request.
-            response (tuple): A tuple of (HTTP response code, response content)
-        """
-        logger.debug("store_response TxnId: %s", txn_id)
-        self.transactions[key] = (txn_id, response)
+    def _store_response(self, key, txn_id, response_deferred):
+        self.transactions[key] = (txn_id, response_deferred)
 
-    def store_client_transaction(self, request, txn_id, response):
-        """Stores the request/response pair of an HTTP transaction.
+    def store_client_transaction(self, request, txn_id, response_deferred):
+        """Stores the request/Promise<response> pair of an HTTP transaction.
 
         Args:
             request (twisted.web.http.Request): The twisted HTTP request. This
             request must have the transaction ID as the last path segment.
-            response (tuple): A tuple of (response code, response dict)
+            response_deferred (Promise<tuple>): A tuple of (response code, response dict)
             txn_id (str): The transaction ID for this request.
         """
-        self.store_response(self._get_key(request), txn_id, response)
+        self._store_response(self._get_key(request), txn_id, response_deferred)
 
     def get_client_transaction(self, request, txn_id):
         """Retrieves a stored response if there was one.
@@ -82,14 +60,14 @@ class HttpTransactionStore(object):
             request must have the transaction ID as the last path segment.
             txn_id (str): The transaction ID for this request.
         Returns:
-            The response tuple.
+            Promise: Resolves to the response tuple.
         Raises:
             KeyError if the transaction was not found.
         """
-        response = self.get_response(self._get_key(request), txn_id)
-        if response is None:
+        response_deferred = self._get_response(self._get_key(request), txn_id)
+        if response_deferred is None:
             raise KeyError("Transaction not found.")
-        return response
+        return response_deferred
 
     def _get_key(self, request):
         token = get_access_token_from_request(request)
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index 5975164b37..7c800ca895 100644
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -19,7 +19,7 @@ from twisted.internet import defer
 
 from synapse.http import servlet
 from synapse.http.servlet import parse_json_object_from_request
-from synapse.rest.client.v1.transactions import HttpTransactionStore
+from synapse.rest.client.v1.transactions import HttpTransactionCache
 
 from ._base import client_v2_patterns
 
@@ -40,18 +40,25 @@ class SendToDeviceRestServlet(servlet.RestServlet):
         super(SendToDeviceRestServlet, self).__init__()
         self.hs = hs
         self.auth = hs.get_auth()
-        self.txns = HttpTransactionStore()
+        self.txns = HttpTransactionCache()
         self.device_message_handler = hs.get_device_message_handler()
 
     @defer.inlineCallbacks
     def on_PUT(self, request, message_type, txn_id):
         try:
-            defer.returnValue(
-                self.txns.get_client_transaction(request, txn_id)
-            )
+            res_deferred = self.txns.get_client_transaction(request, txn_id)
+            res = yield res_deferred
+            defer.returnValue(res)
         except KeyError:
             pass
+        
+        res_deferred = self._put(request, message_type, txn_id)
+        self.txns.store_client_transaction(request, txn_id, res_deferred)
+        res = yield res_deferred
+        defer.returnValue(res)
 
+    @defer.inlineCallbacks
+    def _put(self, request, message_type, txn_id):
         requester = yield self.auth.get_user_by_req(request)
 
         content = parse_json_object_from_request(request)
@@ -63,7 +70,6 @@ class SendToDeviceRestServlet(servlet.RestServlet):
         )
 
         response = (200, {})
-        self.txns.store_client_transaction(request, txn_id, response)
         defer.returnValue(response)