summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
author Nicolas Werner <nicolas.werner@hotmail.de> 2023-10-23 01:40:10 +0200
committer Nicolas Werner <nicolas.werner@hotmail.de> 2023-10-23 01:40:10 +0200
commit d45dc6c77bc42f38048db33a5a7aceac1879c0b8 (patch)
tree 74df0cb84098b950dfba2734945aea978a95a3e2 /src
parent hr tags are self closing (diff)
download nheko-d45dc6c77bc42f38048db33a5a7aceac1879c0b8.tar.xz
Migrate olm sessions to be stored in one database instead of thousands
Diffstat (limited to 'src')
-rw-r--r-- src/Cache.cpp 105
-rw-r--r-- src/Cache_p.h 11
-rw-r--r-- src/encryption/Olm.cpp 2
3 files changed, 91 insertions, 27 deletions
diff --git a/src/Cache.cpp b/src/Cache.cpp
index d975bdc5..8ad850ac 100644
--- a/src/Cache.cpp
+++ b/src/Cache.cpp
@@ -37,7 +37,7 @@
 
 //! Should be changed when a breaking change occurs in the cache format.
 //! This will reset client's data.
-static const std::string CURRENT_CACHE_FORMAT_VERSION{"2023.03.12"};
+static const std::string CURRENT_CACHE_FORMAT_VERSION{"2023.10.22"};
 
 //! Keys used for the DB
 static const std::string_view NEXT_BATCH_KEY("next_batch");
@@ -91,6 +91,8 @@ static constexpr auto INBOUND_MEGOLM_SESSIONS_DB("inbound_megolm_sessions");
 static constexpr auto OUTBOUND_MEGOLM_SESSIONS_DB("outbound_megolm_sessions");
 //! MegolmSessionIndex -> session data about which devices have access to this
 static constexpr auto MEGOLM_SESSIONS_DATA_DB("megolm_sessions_data_db");
+//! Key: curve25519 key + null byte + session_id. Value: json encoded olm session. Not dupsorted.
+static constexpr auto OLM_SESSIONS_DB("olm_sessions.v3");
 
 //! flag to be set, when the db should be compacted on startup
 bool needsCompact = false;
@@ -98,6 +100,21 @@ bool needsCompact = false;
 using CachedReceipts = std::multimap<uint64_t, std::string, std::greater<uint64_t>>;
 using Receipts       = std::map<std::string, std::map<std::string, uint64_t>>;
 
+static std::string // returns curve25519 + '\0' + session_id, used as the key into olmSessionDb_
+combineOlmSessionKeyFromCurveAndSessionId(std::string_view curve25519, std::string_view session_id)
+{
+    std::string combined(curve25519.size() + 1 + session_id.size(), '\0'); // final size, pre-filled with NULs
+    combined.replace(0, curve25519.size(), curve25519); // overwrite the front with the curve key
+    combined.replace(curve25519.size() + 1, session_id.size(), session_id); // leave one NUL, then the id
+    return combined;
+}
+static std::pair<std::string_view, std::string_view> // {text before the first '\0', text after it}; views into input
+splitCurve25519AndOlmSessionId(std::string_view input)
+{
+    auto separator = input.find('\0'); // npos when input contains no NUL byte
+    return std::pair(input.substr(0, separator), input.substr(separator + 1)); // NOTE(review): for npos, npos + 1 wraps to 0, so both halves are the whole input
+}
+
 namespace {
 std::unique_ptr<Cache> instance_ = nullptr;
 }
@@ -412,6 +429,8 @@ Cache::setup()
     outboundMegolmSessionDb_ = lmdb::dbi::open(txn, OUTBOUND_MEGOLM_SESSIONS_DB, MDB_CREATE);
     megolmSessionDataDb_     = lmdb::dbi::open(txn, MEGOLM_SESSIONS_DATA_DB, MDB_CREATE);
 
+    olmSessionDb_ = lmdb::dbi::open(txn, OLM_SESSIONS_DB, MDB_CREATE);
+
     // What rooms are encrypted
     encryptedRooms_   = lmdb::dbi::open(txn, ENCRYPTED_ROOMS_DB, MDB_CREATE);
     eventExpiryBgJob_ = lmdb::dbi::open(txn, EVENT_EXPIRATION_BG_JOB_DB, MDB_CREATE);
@@ -1075,8 +1094,6 @@ Cache::saveOlmSessions(std::vector<std::pair<std::string, mtx::crypto::OlmSessio
 
     auto txn = lmdb::txn::begin(env_);
     for (const auto &[curve25519, session] : sessions) {
-        auto db = getOlmSessionsDb(txn, curve25519);
-
         const auto pickled    = pickle<SessionObject>(session.get(), pickle_secret_);
         const auto session_id = mtx::crypto::session_id(session.get());
 
@@ -1084,7 +1101,9 @@ Cache::saveOlmSessions(std::vector<std::pair<std::string, mtx::crypto::OlmSessio
         stored_session.pickled_session = pickled;
         stored_session.last_message_ts = timestamp;
 
-        db.put(txn, session_id, nlohmann::json(stored_session).dump());
+        olmSessionDb_.put(txn,
+                          combineOlmSessionKeyFromCurveAndSessionId(curve25519, session_id),
+                          nlohmann::json(stored_session).dump());
     }
 
     txn.commit();
@@ -1098,7 +1117,6 @@ Cache::saveOlmSession(const std::string &curve25519,
     using namespace mtx::crypto;
 
     auto txn = lmdb::txn::begin(env_);
-    auto db  = getOlmSessionsDb(txn, curve25519);
 
     const auto pickled    = pickle<SessionObject>(session.get(), pickle_secret_);
     const auto session_id = mtx::crypto::session_id(session.get());
@@ -1107,7 +1125,9 @@ Cache::saveOlmSession(const std::string &curve25519,
     stored_session.pickled_session = pickled;
     stored_session.last_message_ts = timestamp;
 
-    db.put(txn, session_id, nlohmann::json(stored_session).dump());
+    olmSessionDb_.put(txn,
+                      combineOlmSessionKeyFromCurveAndSessionId(curve25519, session_id),
+                      nlohmann::json(stored_session).dump());
 
     txn.commit();
 }
@@ -1119,10 +1139,10 @@ Cache::getOlmSession(const std::string &curve25519, const std::string &session_i
 
     try {
         auto txn = ro_txn(env_);
-        auto db  = getOlmSessionsDb(txn, curve25519);
 
         std::string_view pickled;
-        bool found = db.get(txn, session_id, pickled);
+        bool found = olmSessionDb_.get(
+          txn, combineOlmSessionKeyFromCurveAndSessionId(curve25519, session_id), pickled);
 
         if (found) {
             auto data = nlohmann::json::parse(pickled).get<StoredOlmSession>();
@@ -1141,14 +1161,20 @@ Cache::getLatestOlmSession(const std::string &curve25519)
 
     try {
         auto txn = ro_txn(env_);
-        auto db  = getOlmSessionsDb(txn, curve25519);
 
-        std::string_view session_id, pickled_session;
+        std::string_view key = curve25519, pickled_session;
 
         std::optional<StoredOlmSession> currentNewest;
 
-        auto cursor = lmdb::cursor::open(txn, db);
-        while (cursor.get(session_id, pickled_session, MDB_NEXT)) {
+        auto cursor = lmdb::cursor::open(txn, olmSessionDb_);
+        bool first  = true;
+        while (cursor.get(key, pickled_session, first ? MDB_SET_RANGE : MDB_NEXT)) {
+            first = false;
+
+            auto storedCurve = splitCurve25519AndOlmSessionId(key).first;
+            if (storedCurve != curve25519)
+                break;
+
             auto data = nlohmann::json::parse(pickled_session).get<StoredOlmSession>();
             if (!currentNewest || currentNewest->last_message_ts < data.last_message_ts)
                 currentNewest = data;
@@ -1170,14 +1196,21 @@ Cache::getOlmSessions(const std::string &curve25519)
 
     try {
         auto txn = ro_txn(env_);
-        auto db  = getOlmSessionsDb(txn, curve25519);
 
-        std::string_view session_id, unused;
+        std::string_view key = curve25519, value;
         std::vector<std::string> res;
 
-        auto cursor = lmdb::cursor::open(txn, db);
-        while (cursor.get(session_id, unused, MDB_NEXT))
+        auto cursor = lmdb::cursor::open(txn, olmSessionDb_);
+
+        bool first = true;
+        while (cursor.get(key, value, first ? MDB_SET_RANGE : MDB_NEXT)) {
+            first = false;
+
+            auto [storedCurve, session_id] = splitCurve25519AndOlmSessionId(key);
+            if (storedCurve != curve25519)
+                break;
             res.emplace_back(session_id);
+        }
         cursor.close();
 
         return res;
@@ -1687,6 +1720,46 @@ Cache::runMigrations()
            nhlog::db()->info("Successfully updated states key database format.");
            return true;
        }},
+      {"2023.10.22", // same string as the new CURRENT_CACHE_FORMAT_VERSION
+       [this]() {
+           // migrate olm sessions to a single db
+           try {
+               auto txn     = lmdb::txn::begin(env_, nullptr);
+               auto mainDb  = lmdb::dbi::open(txn); // opened without a name; its keys are treated as database names below
+               auto dbNames = lmdb::cursor::open(txn, mainDb);
+
+               std::string_view dbName;
+               while (dbNames.get(dbName, MDB_NEXT)) {
+                   if (!dbName.starts_with("olm_sessions.v2/")) // only the old per-curve25519-key session dbs
+                       continue;
+
+                   auto curveKey = dbName;
+                   curveKey.remove_prefix(std::string_view("olm_sessions.v2/").size()); // remainder is the curve25519 key
+
+                   auto oldDb     = lmdb::dbi::open(txn, std::string(dbName).c_str()); // copied so c_str() is NUL-terminated
+                   auto olmCursor = lmdb::cursor::open(txn, oldDb);
+
+                   std::string_view session_id, json;
+                   while (olmCursor.get(session_id, json, MDB_NEXT)) {
+                       olmSessionDb_.put( // NOTE(review): curveKey/session_id/json are views obtained from cursors in this same write txn — confirm they stay valid across put()
+                         txn,
+                         combineOlmSessionKeyFromCurveAndSessionId(curveKey, session_id),
+                         json);
+                   }
+
+                   oldDb.drop(txn, true); // NOTE(review): olmCursor is not closed before this drop, and dbNames keeps iterating mainDb afterwards — confirm this is safe
+               }
+
+               txn.commit(); // NOTE(review): dbNames is never closed explicitly before this commit — confirm against LMDB cursor rules for write txns
+           } catch (const lmdb::error &e) {
+               nhlog::db()->critical("Failed to convert olm sessions database in migration! {}",
+                                     e.what());
+               return false; // lmdb error: report this migration as failed
+           }
+
+           nhlog::db()->info("Successfully updated olm sessions database format.");
+           return true;
+       }},
     };
 
     nhlog::db()->info("Running migrations, this may take a while!");
diff --git a/src/Cache_p.h b/src/Cache_p.h
index fcfa5ff3..e59796ed 100644
--- a/src/Cache_p.h
+++ b/src/Cache_p.h
@@ -674,16 +674,6 @@ private:
         return lmdb::dbi::open(txn, "verified", MDB_CREATE);
     }
 
-    //! Retrieves or creates the database that stores the open OLM sessions between our device
-    //! and the given curve25519 key which represents another device.
-    //!
-    //! Each entry is a map from the session_id to the pickled representation of the session.
-    lmdb::dbi getOlmSessionsDb(lmdb::txn &txn, const std::string &curve25519_key)
-    {
-        return lmdb::dbi::open(
-          txn, std::string("olm_sessions.v2/" + curve25519_key).c_str(), MDB_CREATE);
-    }
-
     QString getDisplayName(const mtx::events::StateEvent<mtx::events::state::Member> &event)
     {
         if (!event.content.display_name.empty())
@@ -713,6 +703,7 @@ private:
     lmdb::dbi inboundMegolmSessionDb_;
     lmdb::dbi outboundMegolmSessionDb_;
     lmdb::dbi megolmSessionDataDb_;
+    lmdb::dbi olmSessionDb_;
 
     lmdb::dbi encryptedRooms_;
 
diff --git a/src/encryption/Olm.cpp b/src/encryption/Olm.cpp
index 8993f715..7fa176b0 100644
--- a/src/encryption/Olm.cpp
+++ b/src/encryption/Olm.cpp
@@ -719,7 +719,7 @@ try_olm_decryption(const std::string &sender_key, const mtx::events::msg::OlmCip
             nhlog::crypto()->debug("Updated olm session: {}",
                                    mtx::crypto::session_id(session->get()));
             cache::saveOlmSession(
-              id, std::move(session.value()), QDateTime::currentMSecsSinceEpoch());
+              sender_key, std::move(session.value()), QDateTime::currentMSecsSinceEpoch());
         } catch (const mtx::crypto::olm_exception &e) {
             nhlog::crypto()->debug("failed to decrypt olm message ({}, {}) with {}: {}",
                                    msg.type,