summary refs log tree commit diff
path: root/src/Cache.cpp
diff options
context:
space:
mode:
authorNicolas Werner <nicolas.werner@hotmail.de>2021-07-17 01:27:37 +0200
committerNicolas Werner <nicolas.werner@hotmail.de>2021-07-17 01:27:37 +0200
commit9fadd148715790743cb4e87bfe1854923e59c06b (patch)
tree2f3db93039106328338bd67eac8666b14b5fdaab /src/Cache.cpp
parentFix replies not reloading after fetching them (diff)
downloadnheko-9fadd148715790743cb4e87bfe1854923e59c06b.tar.xz
Store megolm session data in separate database
Diffstat (limited to 'src/Cache.cpp')
-rw-r--r--src/Cache.cpp102
1 files changed, 72 insertions, 30 deletions
diff --git a/src/Cache.cpp b/src/Cache.cpp
index 9304db0e..1c156104 100644
--- a/src/Cache.cpp
+++ b/src/Cache.cpp
@@ -78,6 +78,8 @@ constexpr auto ENCRYPTED_ROOMS_DB("encrypted_rooms");
 constexpr auto INBOUND_MEGOLM_SESSIONS_DB("inbound_megolm_sessions");
 //! MegolmSessionIndex -> pickled OlmOutboundGroupSession
 constexpr auto OUTBOUND_MEGOLM_SESSIONS_DB("outbound_megolm_sessions");
+//! MegolmSessionIndex -> session data about which devices have access to this
+constexpr auto MEGOLM_SESSIONS_DATA_DB("megolm_sessions_data_db");
 
 using CachedReceipts = std::multimap<uint64_t, std::string, std::greater<uint64_t>>;
 using Receipts       = std::map<std::string, std::map<std::string, uint64_t>>;
@@ -284,6 +286,7 @@ Cache::setup()
         // Session management
         inboundMegolmSessionDb_  = lmdb::dbi::open(txn, INBOUND_MEGOLM_SESSIONS_DB, MDB_CREATE);
         outboundMegolmSessionDb_ = lmdb::dbi::open(txn, OUTBOUND_MEGOLM_SESSIONS_DB, MDB_CREATE);
+        megolmSessionDataDb_     = lmdb::dbi::open(txn, MEGOLM_SESSIONS_DATA_DB, MDB_CREATE);
 
         txn.commit();
 
@@ -387,9 +390,14 @@ Cache::importSessionKeys(const mtx::crypto::ExportedSessionKeys &keys)
                 index.session_id = s.session_id;
                 index.sender_key = s.sender_key;
 
+                GroupSessionData data{};
+                data.forwarding_curve25519_key_chain = s.forwarding_curve25519_key_chain;
+                if (s.sender_claimed_keys.count("ed25519"))
+                        data.sender_claimed_ed25519_key = s.sender_claimed_keys.at("ed25519");
+
                 auto exported_session = mtx::crypto::import_session(s.session_key);
 
-                saveInboundMegolmSession(index, std::move(exported_session));
+                saveInboundMegolmSession(index, std::move(exported_session), data);
                 ChatPage::instance()->receivedSessionKey(index.room_id, index.session_id);
         }
 }
@@ -400,7 +408,8 @@ Cache::importSessionKeys(const mtx::crypto::ExportedSessionKeys &keys)
 
 void
 Cache::saveInboundMegolmSession(const MegolmSessionIndex &index,
-                                mtx::crypto::InboundGroupSessionPtr session)
+                                mtx::crypto::InboundGroupSessionPtr session,
+                                const GroupSessionData &data)
 {
         using namespace mtx::crypto;
         const auto key     = json(index).dump();
@@ -420,6 +429,7 @@ Cache::saveInboundMegolmSession(const MegolmSessionIndex &index,
         }
 
         inboundMegolmSessionDb_.put(txn, key, pickled);
+        megolmSessionDataDb_.put(txn, key, json(data).dump());
         txn.commit();
 }
 
@@ -464,7 +474,7 @@ Cache::inboundMegolmSessionExists(const MegolmSessionIndex &index)
 
 void
 Cache::updateOutboundMegolmSession(const std::string &room_id,
-                                   const OutboundGroupSessionData &data_,
+                                   const GroupSessionData &data_,
                                    mtx::crypto::OutboundGroupSessionPtr &ptr)
 {
         using namespace mtx::crypto;
@@ -472,18 +482,20 @@ Cache::updateOutboundMegolmSession(const std::string &room_id,
         if (!outboundMegolmSessionExists(room_id))
                 return;
 
-        OutboundGroupSessionData data = data_;
-        data.message_index            = olm_outbound_group_session_message_index(ptr.get());
-        data.session_id               = mtx::crypto::session_id(ptr.get());
-        data.session_key              = mtx::crypto::session_key(ptr.get());
+        GroupSessionData data = data_;
+        data.message_index    = olm_outbound_group_session_message_index(ptr.get());
+        MegolmSessionIndex index;
+        index.room_id    = room_id;
+        index.sender_key = olm::client()->identity_keys().ed25519;
+        index.session_id = mtx::crypto::session_id(ptr.get());
 
         // Save the updated pickled data for the session.
         json j;
-        j["data"]    = data;
         j["session"] = pickle<OutboundSessionObject>(ptr.get(), SECRET);
 
         auto txn = lmdb::txn::begin(env_);
         outboundMegolmSessionDb_.put(txn, room_id, j.dump());
+        megolmSessionDataDb_.put(txn, json(index).dump(), json(data).dump());
         txn.commit();
 }
 
@@ -498,24 +510,32 @@ Cache::dropOutboundMegolmSession(const std::string &room_id)
         {
                 auto txn = lmdb::txn::begin(env_);
                 outboundMegolmSessionDb_.del(txn, room_id);
+                // don't delete session data, so that we can still share the session.
                 txn.commit();
         }
 }
 
 void
 Cache::saveOutboundMegolmSession(const std::string &room_id,
-                                 const OutboundGroupSessionData &data,
+                                 const GroupSessionData &data_,
                                  mtx::crypto::OutboundGroupSessionPtr &session)
 {
         using namespace mtx::crypto;
         const auto pickled = pickle<OutboundSessionObject>(session.get(), SECRET);
 
+        GroupSessionData data = data_;
+        data.message_index    = olm_outbound_group_session_message_index(session.get());
+        MegolmSessionIndex index;
+        index.room_id    = room_id;
+        index.sender_key = olm::client()->identity_keys().ed25519;
+        index.session_id = mtx::crypto::session_id(session.get());
+
         json j;
-        j["data"]    = data;
         j["session"] = pickled;
 
         auto txn = lmdb::txn::begin(env_);
         outboundMegolmSessionDb_.put(txn, room_id, j.dump());
+        megolmSessionDataDb_.put(txn, json(index).dump(), json(data).dump());
         txn.commit();
 }
 
@@ -544,8 +564,17 @@ Cache::getOutboundMegolmSession(const std::string &room_id)
                 auto obj = json::parse(value);
 
                 OutboundGroupSessionDataRef ref{};
-                ref.data    = obj.at("data").get<OutboundGroupSessionData>();
                 ref.session = unpickle<OutboundSessionObject>(obj.at("session"), SECRET);
+
+                MegolmSessionIndex index;
+                index.room_id    = room_id;
+                index.sender_key = olm::client()->identity_keys().ed25519;
+                index.session_id = mtx::crypto::session_id(ref.session.get());
+
+                if (megolmSessionDataDb_.get(txn, json(index).dump(), value)) {
+                        ref.data = nlohmann::json::parse(value).get<GroupSessionData>();
+                }
+
                 return ref;
         } catch (std::exception &e) {
                 nhlog::db()->error("Failed to retrieve outbound Megolm Session: {}", e.what());
@@ -829,6 +858,7 @@ Cache::deleteData()
 
         lmdb::dbi_close(env_, inboundMegolmSessionDb_);
         lmdb::dbi_close(env_, outboundMegolmSessionDb_);
+        lmdb::dbi_close(env_, megolmSessionDataDb_);
 
         env_.close();
 
@@ -3525,6 +3555,7 @@ to_json(json &j, const UserKeyCache &info)
 {
         j["device_keys"]        = info.device_keys;
         j["seen_device_keys"]   = info.seen_device_keys;
+        j["seen_device_ids"]    = info.seen_device_ids;
         j["master_keys"]        = info.master_keys;
         j["master_key_changed"] = info.master_key_changed;
         j["user_signing_keys"]  = info.user_signing_keys;
@@ -3538,6 +3569,7 @@ from_json(const json &j, UserKeyCache &info)
 {
         info.device_keys = j.value("device_keys", std::map<std::string, mtx::crypto::DeviceKeys>{});
         info.seen_device_keys   = j.value("seen_device_keys", std::set<std::string>{});
+        info.seen_device_ids    = j.value("seen_device_ids", std::set<std::string>{});
         info.master_keys        = j.value("master_keys", mtx::crypto::CrossSigningKeys{});
         info.master_key_changed = j.value("master_key_changed", false);
         info.user_signing_keys  = j.value("user_signing_keys", mtx::crypto::CrossSigningKeys{});
@@ -3634,6 +3666,15 @@ Cache::updateUserKeys(const std::string &sync_token, const mtx::responses::Query
                                                         keyReused = true;
                                                         break;
                                                 }
+                                                if (updateToWrite.seen_device_ids.count(
+                                                      device_id)) {
+                                                        nhlog::crypto()->warn(
+                                                          "device_id '{}' reused by ({})",
+                                                          device_id,
+                                                          user);
+                                                        keyReused = true;
+                                                        break;
+                                                }
                                         }
 
                                         if (!keyReused && !oldDeviceKeys.count(device_id))
@@ -3644,6 +3685,7 @@ Cache::updateUserKeys(const std::string &sync_token, const mtx::responses::Query
                                         (void)key_id;
                                         updateToWrite.seen_device_keys.insert(key);
                                 }
+                                updateToWrite.seen_device_ids.insert(device_id);
                         }
                 }
                 db.put(txn, user, json(updateToWrite).dump());
@@ -4077,17 +4119,15 @@ from_json(const json &j, MemberInfo &info)
 }
 
 void
-to_json(nlohmann::json &obj, const DeviceAndMasterKeys &msg)
+to_json(nlohmann::json &obj, const DeviceKeysToMsgIndex &msg)
 {
-        obj["devices"]     = msg.devices;
-        obj["master_keys"] = msg.master_keys;
+        obj["deviceids"] = msg.deviceids;
 }
 
 void
-from_json(const nlohmann::json &obj, DeviceAndMasterKeys &msg)
+from_json(const nlohmann::json &obj, DeviceKeysToMsgIndex &msg)
 {
-        msg.devices     = obj.at("devices").get<decltype(msg.devices)>();
-        msg.master_keys = obj.at("master_keys").get<decltype(msg.master_keys)>();
+        msg.deviceids = obj.at("deviceids").get<decltype(msg.deviceids)>();
 }
 
 void
@@ -4099,30 +4139,31 @@ to_json(nlohmann::json &obj, const SharedWithUsers &msg)
 void
 from_json(const nlohmann::json &obj, SharedWithUsers &msg)
 {
-        msg.keys = obj.at("keys").get<std::map<std::string, DeviceAndMasterKeys>>();
+        msg.keys = obj.at("keys").get<std::map<std::string, DeviceKeysToMsgIndex>>();
 }
 
 void
-to_json(nlohmann::json &obj, const OutboundGroupSessionData &msg)
+to_json(nlohmann::json &obj, const GroupSessionData &msg)
 {
-        obj["session_id"]    = msg.session_id;
-        obj["session_key"]   = msg.session_key;
         obj["message_index"] = msg.message_index;
         obj["ts"]            = msg.timestamp;
 
-        obj["initially"] = msg.initially;
+        obj["sender_claimed_ed25519_key"]      = msg.sender_claimed_ed25519_key;
+        obj["forwarding_curve25519_key_chain"] = msg.forwarding_curve25519_key_chain;
+
         obj["currently"] = msg.currently;
 }
 
 void
-from_json(const nlohmann::json &obj, OutboundGroupSessionData &msg)
+from_json(const nlohmann::json &obj, GroupSessionData &msg)
 {
-        msg.session_id    = obj.at("session_id");
-        msg.session_key   = obj.at("session_key");
         msg.message_index = obj.at("message_index");
         msg.timestamp     = obj.value("ts", 0ULL);
 
-        msg.initially = obj.value("initially", SharedWithUsers{});
+        msg.sender_claimed_ed25519_key = obj.value("sender_claimed_ed25519_key", "");
+        msg.forwarding_curve25519_key_chain =
+          obj.value("forwarding_curve25519_key_chain", std::vector<std::string>{});
+
         msg.currently = obj.value("currently", SharedWithUsers{});
 }
 
@@ -4522,7 +4563,7 @@ isRoomMember(const std::string &user_id, const std::string &room_id)
 //
 void
 saveOutboundMegolmSession(const std::string &room_id,
-                          const OutboundGroupSessionData &data,
+                          const GroupSessionData &data,
                           mtx::crypto::OutboundGroupSessionPtr &session)
 {
         instance_->saveOutboundMegolmSession(room_id, data, session);
@@ -4539,7 +4580,7 @@ outboundMegolmSessionExists(const std::string &room_id) noexcept
 }
 void
 updateOutboundMegolmSession(const std::string &room_id,
-                            const OutboundGroupSessionData &data,
+                            const GroupSessionData &data,
                             mtx::crypto::OutboundGroupSessionPtr &session)
 {
         instance_->updateOutboundMegolmSession(room_id, data, session);
@@ -4566,9 +4607,10 @@ exportSessionKeys()
 //
 void
 saveInboundMegolmSession(const MegolmSessionIndex &index,
-                         mtx::crypto::InboundGroupSessionPtr session)
+                         mtx::crypto::InboundGroupSessionPtr session,
+                         const GroupSessionData &data)
 {
-        instance_->saveInboundMegolmSession(index, std::move(session));
+        instance_->saveInboundMegolmSession(index, std::move(session), data);
 }
 mtx::crypto::InboundGroupSessionPtr
 getInboundMegolmSession(const MegolmSessionIndex &index)