summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/handlers/e2e_room_keys.py46
-rw-r--r--synapse/rest/client/v2_alpha/room_keys.py37
-rw-r--r--synapse/storage/e2e_room_keys.py20
3 files changed, 84 insertions, 19 deletions
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index 78c838a829..15e3beb5ed 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -20,8 +20,7 @@ from canonicaljson import encode_canonical_json
 from twisted.internet import defer
 
 from synapse.api.errors import SynapseError, CodeMessageException
-from synapse.types import get_domain_from_id
-from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
+from synapse.util.async import Linearizer
 from synapse.util.retryutils import NotRetryingDestination
 
 logger = logging.getLogger(__name__)
@@ -30,6 +29,7 @@ logger = logging.getLogger(__name__)
 class E2eRoomKeysHandler(object):
     def __init__(self, hs):
         self.store = hs.get_datastore()
+        self._upload_linearizer = async.Linearizer("upload_room_keys_lock")
 
     @defer.inlineCallbacks
     def get_room_keys(self, user_id, version, room_id, session_id):
@@ -37,24 +37,40 @@ class E2eRoomKeysHandler(object):
         defer.returnValue(results)
 
     @defer.inlineCallbacks
+    def delete_room_keys(self, user_id, version, room_id, session_id):
+        yield self.store.delete_e2e_room_keys(user_id, version, room_id, session_id)
+
+    @defer.inlineCallbacks
     def upload_room_keys(self, user_id, version, room_keys):
 
         # TODO: Validate the JSON to make sure it has the right keys.
 
-        # go through the room_keys
-        for room_id in room_keys['rooms']:
-            for session_id in room_keys['rooms'][room_id]['sessions']:
-                session = room_keys['rooms'][room_id]['sessions'][session_id]
-
-                # get a lock
+        # XXX: perhaps we should use a finer grained lock here?
+        with (yield self._upload_linearizer.queue(user_id):
 
-                # get the room_key for this particular row
-                yield self.store.get_e2e_room_key()
+            # go through the room_keys
+            for room_id in room_keys['rooms']:
+                for session_id in room_keys['rooms'][room_id]['sessions']:
+                    room_key = room_keys['rooms'][room_id]['sessions'][session_id]
 
-                # check whether we merge or not
-                if()
+                    # get the room_key for this particular row
+                    current_room_key = yield self.store.get_e2e_room_key(
+                        user_id, version, room_id, session_id
+                    )
 
-                # if so, we set it
-                yield self.store.set_e2e_room_key()
+                    # check whether we merge or not. spelling it out with if/elifs rather than
+                    # lots of booleans for legibility.
+                    replace = False
+                    if current_room_key:
+                        if room_key['is_verified'] and not current_room_key['is_verified']:
+                            replace = True
+                        elif room_key['first_message_index'] < current_room_key['first_message_index']:
+                            replace = True
+                        elif room_key['forwarded_count'] < room_key['forwarded_count']:
+                            replace = True
 
-                # release the lock
+                    # if so, we set the new room_key
+                    if replace:
+                        yield self.store.set_e2e_room_key(
+                            user_id, version, room_id, session_id, room_key
+                        )
diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py
index 9b93001919..7291018a48 100644
--- a/synapse/rest/client/v2_alpha/room_keys.py
+++ b/synapse/rest/client/v2_alpha/room_keys.py
@@ -28,7 +28,7 @@ from ._base import client_v2_patterns
 logger = logging.getLogger(__name__)
 
 
-class RoomKeysUploadServlet(RestServlet):
+class RoomKeysServlet(RestServlet):
     PATTERNS = client_v2_patterns("/room_keys/keys(/(?P<room_id>[^/]+))?(/(?P<session_id>[^/]+))?$")
 
     def __init__(self, hs):
@@ -41,16 +41,45 @@ class RoomKeysUploadServlet(RestServlet):
         self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
 
     @defer.inlineCallbacks
-    def on_POST(self, request, room_id, session_id):
-        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+    def on_PUT(self, request, room_id, session_id):
+        requester = yield self.auth.get_user_by_req(request, allow_guest=False)
         user_id = requester.user.to_string()
         body = parse_json_object_from_request(request)
+        version = request.args.get("version", None)
+
+        if session_id:
+            body = { "sessions": { session_id : body } }
+
+        if room_id:
+            body = { "rooms": { room_id : body } }
 
         result = yield self.e2e_room_keys_handler.upload_room_keys(
             user_id, version, body
         )
         defer.returnValue((200, result))
 
+    @defer.inlineCallbacks
+    def on_GET(self, request, room_id, session_id):
+        requester = yield self.auth.get_user_by_req(request, allow_guest=False)
+        user_id = requester.user.to_string()
+        version = request.args.get("version", None)
+
+        room_keys = yield self.e2e_room_keys_handler.get_room_keys(
+            user_id, version, room_id, session_id
+        )
+        defer.returnValue((200, room_keys))
+
+    @defer.inlineCallbacks
+    def on_DELETE(self, request, room_id, session_id):
+        requester = yield self.auth.get_user_by_req(request, allow_guest=False)
+        user_id = requester.user.to_string()
+        version = request.args.get("version", None)
+
+        yield self.e2e_room_keys_handler.delete_room_keys(
+            user_id, version, room_id, session_id
+        )
+        defer.returnValue((200, {}))
+
 
 def register_servlets(hs, http_server):
-    RoomKeysUploadServlet(hs).register(http_server)
+    RoomKeysServlet(hs).register(http_server)
diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py
index 9f6d47e1b6..903dc083f8 100644
--- a/synapse/storage/e2e_room_keys.py
+++ b/synapse/storage/e2e_room_keys.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 from twisted.internet import defer
 
 from synapse.util.caches.descriptors import cached
@@ -77,6 +78,9 @@ class EndToEndRoomKeyStore(SQLBaseStore):
         )
 
 
+    # XXX: this isn't currently used and isn't tested anywhere
+    # it could be used in future for bulk-uploading new versions of room_keys
+    # for a user or something though.
     def set_e2e_room_keys(self, user_id, version, room_keys):
 
         def _set_e2e_room_keys_txn(txn):
@@ -131,3 +135,19 @@ class EndToEndRoomKeyStore(SQLBaseStore):
         sessions = {}
         sessions['rooms'][roomId]['sessions'][session_id] = row for row in rows;
         defer.returnValue(sessions);
+
+    @defer.inlineCallbacks
+    def delete_e2e_room_keys(self, user_id, version, room_id, session_id):
+
+        keyvalues={
+            "user_id": user_id,
+            "version": version,
+        }
+        if room_id: keyvalues['room_id'] = room_id
+        if session_id: keyvalues['session_id'] = session_id
+
+        yield self._simple_delete(
+            table="e2e_room_keys",
+            keyvalues=keyvalues,
+            desc="delete_e2e_room_keys",
+        )