summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--src/Cache.cpp23
-rw-r--r--src/Cache_p.h2
-rw-r--r--src/encryption/Olm.cpp149
3 files changed, 105 insertions, 69 deletions
diff --git a/src/Cache.cpp b/src/Cache.cpp
index e090e40d..2784cf50 100644
--- a/src/Cache.cpp
+++ b/src/Cache.cpp
@@ -913,6 +913,29 @@ Cache::getMegolmSessionData(const MegolmSessionIndex &index)
 //
 
 void
+Cache::saveOlmSessions(std::vector<std::pair<std::string, mtx::crypto::OlmSessionPtr>> sessions,
+                       uint64_t timestamp)
+{
+    using namespace mtx::crypto;
+
+    auto txn = lmdb::txn::begin(env_);
+    for (const auto &[curve25519, session] : sessions) {
+        auto db = getOlmSessionsDb(txn, curve25519);
+
+        const auto pickled    = pickle<SessionObject>(session.get(), pickle_secret_);
+        const auto session_id = mtx::crypto::session_id(session.get());
+
+        StoredOlmSession stored_session;
+        stored_session.pickled_session = pickled;
+        stored_session.last_message_ts = timestamp;
+
+        db.put(txn, session_id, nlohmann::json(stored_session).dump());
+    }
+
+    txn.commit();
+}
+
+void
 Cache::saveOlmSession(const std::string &curve25519,
                       mtx::crypto::OlmSessionPtr session,
                       uint64_t timestamp)
diff --git a/src/Cache_p.h b/src/Cache_p.h
index 1694adb7..742e4aab 100644
--- a/src/Cache_p.h
+++ b/src/Cache_p.h
@@ -277,6 +277,8 @@ public:
     void saveOlmSession(const std::string &curve25519,
                         mtx::crypto::OlmSessionPtr session,
                         uint64_t timestamp);
+    void saveOlmSessions(std::vector<std::pair<std::string, mtx::crypto::OlmSessionPtr>> sessions,
+                         uint64_t timestamp);
     std::vector<std::string> getOlmSessions(const std::string &curve25519);
     std::optional<mtx::crypto::OlmSessionPtr>
     getOlmSession(const std::string &curve25519, const std::string &session_id);
diff --git a/src/encryption/Olm.cpp b/src/encryption/Olm.cpp
index 7ada2f92..a9d5b1c2 100644
--- a/src/encryption/Olm.cpp
+++ b/src/encryption/Olm.cpp
@@ -1299,78 +1299,83 @@ send_encrypted_to_device_messages(const std::map<std::string, std::vector<std::s
 
     auto our_curve = olm::client()->identity_keys().curve25519;
 
-    for (const auto &[user, devices] : targets) {
-        auto deviceKeys = cache::client()->userKeys(user);
+    {
+        auto currentTime = QDateTime::currentSecsSinceEpoch();
+        std::vector<std::pair<std::string, mtx::crypto::OlmSessionPtr>> sessionsToPersist;
 
-        // no keys for user, query them
-        if (!deviceKeys) {
-            keysToQuery[user] = devices;
-            continue;
-        }
+        for (const auto &[user, devices] : targets) {
+            auto deviceKeys = cache::client()->userKeys(user);
 
-        auto deviceTargets = devices;
-        if (devices.empty()) {
-            deviceTargets.clear();
-            deviceTargets.reserve(deviceKeys->device_keys.size());
-            for (const auto &[device, keys] : deviceKeys->device_keys) {
-                (void)keys;
-                deviceTargets.push_back(device);
+            // no keys for user, query them
+            if (!deviceKeys) {
+                keysToQuery[user] = devices;
+                continue;
             }
-        }
 
-        for (const auto &device : deviceTargets) {
-            if (!deviceKeys->device_keys.count(device)) {
-                keysToQuery[user] = {};
-                break;
+            auto deviceTargets = devices;
+            if (devices.empty()) {
+                deviceTargets.clear();
+                deviceTargets.reserve(deviceKeys->device_keys.size());
+                for (const auto &[device, keys] : deviceKeys->device_keys) {
+                    (void)keys;
+                    deviceTargets.push_back(device);
+                }
             }
 
-            auto d = deviceKeys->device_keys.at(device);
+            for (const auto &device : deviceTargets) {
+                if (!deviceKeys->device_keys.count(device)) {
+                    keysToQuery[user] = {};
+                    break;
+                }
 
-            if (!d.keys.count("curve25519:" + device) || !d.keys.count("ed25519:" + device)) {
-                nhlog::crypto()->warn("Skipping device {} since it has no keys!", device);
-                continue;
-            }
+                const auto &d = deviceKeys->device_keys.at(device);
 
-            auto device_curve = d.keys.at("curve25519:" + device);
-            if (device_curve == our_curve) {
-                nhlog::crypto()->warn("Skipping our own device, since sending "
-                                      "ourselves olm messages makes no sense.");
-                continue;
-            }
+                if (!d.keys.count("curve25519:" + device) || !d.keys.count("ed25519:" + device)) {
+                    nhlog::crypto()->warn("Skipping device {} since it has no keys!", device);
+                    continue;
+                }
 
-            auto session = cache::getLatestOlmSession(device_curve);
-            if (!session || force_new_session) {
-                auto currentTime = QDateTime::currentSecsSinceEpoch();
-                if (rateLimit.value(QPair(user, device)) + 60 * 60 * 10 < currentTime) {
-                    claims.one_time_keys[user][device] = mtx::crypto::SIGNED_CURVE25519;
-                    pks[user][device].ed25519          = d.keys.at("ed25519:" + device);
-                    pks[user][device].curve25519       = d.keys.at("curve25519:" + device);
+                auto device_curve = d.keys.at("curve25519:" + device);
+                if (device_curve == our_curve) {
+                    nhlog::crypto()->warn("Skipping our own device, since sending "
+                                          "ourselves olm messages makes no sense.");
+                    continue;
+                }
 
-                    rateLimit.insert(QPair(user, device), currentTime);
-                } else {
-                    nhlog::crypto()->warn("Not creating new session with {}:{} "
-                                          "because of rate limit",
-                                          user,
-                                          device);
+                auto session = cache::getLatestOlmSession(device_curve);
+                if (!session || force_new_session) {
+                    if (rateLimit.value(QPair(user, device)) + 60 * 60 * 10 < currentTime) {
+                        claims.one_time_keys[user][device] = mtx::crypto::SIGNED_CURVE25519;
+                        pks[user][device].ed25519          = d.keys.at("ed25519:" + device);
+                        pks[user][device].curve25519       = d.keys.at("curve25519:" + device);
+
+                        rateLimit.insert(QPair(user, device), currentTime);
+                    } else {
+                        nhlog::crypto()->warn("Not creating new session with {}:{} "
+                                              "because of rate limit",
+                                              user,
+                                              device);
+                    }
+                    continue;
                 }
-                continue;
-            }
 
-            messages[mtx::identifiers::parse<mtx::identifiers::User>(user)][device] =
-              olm::client()
-                ->create_olm_encrypted_content(session->get(),
-                                               ev_json,
-                                               UserId(user),
-                                               d.keys.at("ed25519:" + device),
-                                               device_curve)
-                .get<mtx::events::msg::OlmEncrypted>();
+                messages[mtx::identifiers::parse<mtx::identifiers::User>(user)][device] =
+                  olm::client()
+                    ->create_olm_encrypted_content(session->get(),
+                                                   ev_json,
+                                                   UserId(user),
+                                                   d.keys.at("ed25519:" + device),
+                                                   device_curve)
+                    .get<mtx::events::msg::OlmEncrypted>();
+                sessionsToPersist.emplace_back(d.keys.at("curve25519:" + device),
+                                               std::move(*session));
+            }
+        }
 
+        if (!sessionsToPersist.empty()) {
             try {
-                nhlog::crypto()->debug("Updated olm session: {}",
-                                       mtx::crypto::session_id(session->get()));
-                cache::saveOlmSession(d.keys.at("curve25519:" + device),
-                                      std::move(*session),
-                                      QDateTime::currentMSecsSinceEpoch());
+                nhlog::crypto()->debug("Updated olm sessions: {}", sessionsToPersist.size());
+                cache::client()->saveOlmSessions(std::move(sessionsToPersist), currentTime);
             } catch (const lmdb::error &e) {
                 nhlog::db()->critical("failed to save outbound olm session: {}", e.what());
             } catch (const mtx::crypto::olm_exception &e) {
@@ -1395,6 +1400,9 @@ send_encrypted_to_device_messages(const std::map<std::string, std::vector<std::s
                                          mtx::http::RequestErr) {
             std::map<mtx::identifiers::User, std::map<std::string, mtx::events::msg::OlmEncrypted>>
               messages;
+            auto currentTime = QDateTime::currentSecsSinceEpoch();
+            std::vector<std::pair<std::string, mtx::crypto::OlmSessionPtr>> sessionsToPersist;
+
             for (const auto &[user_id, retrieved_devices] : res.one_time_keys) {
                 nhlog::net()->debug("claimed keys for {}", user_id);
                 if (retrieved_devices.size() == 0) {
@@ -1440,21 +1448,24 @@ send_encrypted_to_device_messages(const std::map<std::string, std::vector<std::s
                           session.get(), ev_json, UserId(user_id), sign_key, id_key)
                         .get<mtx::events::msg::OlmEncrypted>();
 
-                    try {
-                        nhlog::crypto()->debug("Updated olm session: {}",
-                                               mtx::crypto::session_id(session.get()));
-                        cache::saveOlmSession(
-                          id_key, std::move(session), QDateTime::currentMSecsSinceEpoch());
-                    } catch (const lmdb::error &e) {
-                        nhlog::db()->critical("failed to save outbound olm session: {}", e.what());
-                    } catch (const mtx::crypto::olm_exception &e) {
-                        nhlog::crypto()->critical("failed to pickle outbound olm session: {}",
-                                                  e.what());
-                    }
+                    sessionsToPersist.emplace_back(id_key, std::move(session));
                 }
                 nhlog::net()->info("send_to_device: {}", user_id);
             }
 
+            if (!sessionsToPersist.empty()) {
+                try {
+                    nhlog::crypto()->debug("Updated (new) olm sessions: {}",
+                                           sessionsToPersist.size());
+                    cache::client()->saveOlmSessions(std::move(sessionsToPersist), currentTime);
+                } catch (const lmdb::error &e) {
+                    nhlog::db()->critical("failed to save outbound olm session: {}", e.what());
+                } catch (const mtx::crypto::olm_exception &e) {
+                    nhlog::crypto()->critical("failed to pickle outbound olm session: {}",
+                                              e.what());
+                }
+            }
+
             if (!messages.empty())
                 http::client()->send_to_device<mtx::events::msg::OlmEncrypted>(
                   http::client()->generate_txn_id(), messages, [](mtx::http::RequestErr err) {