summary refs log tree commit diff
path: root/synapse/storage/e2e_room_keys.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/e2e_room_keys.py')
-rw-r--r--synapse/storage/e2e_room_keys.py69
1 files changed, 28 insertions, 41 deletions
diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py
index 8efca11a8c..c11417c415 100644
--- a/synapse/storage/e2e_room_keys.py
+++ b/synapse/storage/e2e_room_keys.py
@@ -44,30 +44,21 @@ class EndToEndRoomKeyStore(SQLBaseStore):
 
     def set_e2e_room_key(self, user_id, version, room_id, session_id, room_key):
 
-        def _set_e2e_room_key_txn(txn):
-
-            self._simple_upsert_txn(
-                txn,
-                table="e2e_room_keys",
-                keyvalues={
-                    "user_id": user_id,
-                    "room_id": room_id,
-                    "session_id": session_id,
-                },
-                values={
-                    "version": version,
-                    "first_message_index": room_key['first_message_index'],
-                    "forwarded_count": room_key['forwarded_count'],
-                    "is_verified": room_key['is_verified'],
-                    "session_data": room_key['session_data'],
-                },
-                lock=False,
-            )
-
-            return True
-
-        return self.runInteraction(
-            "set_e2e_room_key", _set_e2e_room_key_txn
+        yield self._simple_upsert(
+            table="e2e_room_keys",
+            keyvalues={
+                "user_id": user_id,
+                "room_id": room_id,
+                "session_id": session_id,
+            },
+            values={
+                "version": version,
+                "first_message_index": room_key['first_message_index'],
+                "forwarded_count": room_key['forwarded_count'],
+                "is_verified": room_key['is_verified'],
+                "session_data": room_key['session_data'],
+            },
+            lock=False,
         )
 
     # XXX: this isn't currently used and isn't tested anywhere
@@ -107,7 +98,9 @@ class EndToEndRoomKeyStore(SQLBaseStore):
         )
 
     @defer.inlineCallbacks
-    def get_e2e_room_keys(self, user_id, version, room_id, session_id):
+    def get_e2e_room_keys(
+        self, user_id, version, room_id=room_id, session_id=session_id
+    ):
 
         keyvalues = {
             "user_id": user_id,
@@ -115,8 +108,8 @@ class EndToEndRoomKeyStore(SQLBaseStore):
         }
         if room_id:
             keyvalues['room_id'] = room_id
-        if session_id:
-            keyvalues['session_id'] = session_id
+            if session_id:
+                keyvalues['session_id'] = session_id
 
         rows = yield self._simple_select_list(
             table="e2e_room_keys",
@@ -133,18 +126,10 @@ class EndToEndRoomKeyStore(SQLBaseStore):
             desc="get_e2e_room_keys",
         )
 
-        # perlesque autovivification from https://stackoverflow.com/a/19829714/6764493
-        class AutoVivification(dict):
-            def __getitem__(self, item):
-                try:
-                    return dict.__getitem__(self, item)
-                except KeyError:
-                    value = self[item] = type(self)()
-                    return value
-
-        sessions = AutoVivification()
+        sessions = {}
         for row in rows:
-            sessions['rooms'][row['room_id']]['sessions'][row['session_id']] = {
+            room_entry = sessions['rooms'].setdefault(row['room_id'], {"sessions": {}})
+            room_entry['sessions'][row['session_id']] = {
                 "first_message_index": row["first_message_index"],
                 "forwarded_count": row["forwarded_count"],
                 "is_verified": row["is_verified"],
@@ -154,7 +139,9 @@ class EndToEndRoomKeyStore(SQLBaseStore):
         defer.returnValue(sessions)
 
     @defer.inlineCallbacks
-    def delete_e2e_room_keys(self, user_id, version, room_id, session_id):
+    def delete_e2e_room_keys(
+        self, user_id, version, room_id=room_id, session_id=session_id
+    ):
 
         keyvalues = {
             "user_id": user_id,
@@ -162,8 +149,8 @@ class EndToEndRoomKeyStore(SQLBaseStore):
         }
         if room_id:
             keyvalues['room_id'] = room_id
-        if session_id:
-            keyvalues['session_id'] = session_id
+            if session_id:
+                keyvalues['session_id'] = session_id
 
         yield self._simple_delete(
             table="e2e_room_keys",