summary refs log tree commit diff
path: root/src/Cache.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/Cache.cpp')
-rw-r--r--src/Cache.cpp145
1 files changed, 138 insertions, 7 deletions
diff --git a/src/Cache.cpp b/src/Cache.cpp

index 5a0740f1..d9db99b0 100644 --- a/src/Cache.cpp +++ b/src/Cache.cpp
@@ -40,7 +40,7 @@ //! Should be changed when a breaking change occurs in the cache format. //! This will reset client's data. -static const std::string CURRENT_CACHE_FORMAT_VERSION("2020.07.05"); +static const std::string CURRENT_CACHE_FORMAT_VERSION("2020.10.20"); static const std::string SECRET("secret"); static lmdb::val NEXT_BATCH_KEY("next_batch"); @@ -437,7 +437,9 @@ Cache::getOutboundMegolmSession(const std::string &room_id) // void -Cache::saveOlmSession(const std::string &curve25519, mtx::crypto::OlmSessionPtr session) +Cache::saveOlmSession(const std::string &curve25519, + mtx::crypto::OlmSessionPtr session, + uint64_t timestamp) { using namespace mtx::crypto; @@ -447,7 +449,11 @@ Cache::saveOlmSession(const std::string &curve25519, mtx::crypto::OlmSessionPtr const auto pickled = pickle<SessionObject>(session.get(), SECRET); const auto session_id = mtx::crypto::session_id(session.get()); - lmdb::dbi_put(txn, db, lmdb::val(session_id), lmdb::val(pickled)); + StoredOlmSession stored_session; + stored_session.pickled_session = pickled; + stored_session.last_message_ts = timestamp; + + lmdb::dbi_put(txn, db, lmdb::val(session_id), lmdb::val(json(stored_session).dump())); txn.commit(); } @@ -466,13 +472,44 @@ Cache::getOlmSession(const std::string &curve25519, const std::string &session_i txn.commit(); if (found) { - auto data = std::string(pickled.data(), pickled.size()); - return unpickle<SessionObject>(data, SECRET); + std::string_view raw(pickled.data(), pickled.size()); + auto data = json::parse(raw).get<StoredOlmSession>(); + return unpickle<SessionObject>(data.pickled_session, SECRET); } return std::nullopt; } +std::optional<mtx::crypto::OlmSessionPtr> +Cache::getLatestOlmSession(const std::string &curve25519) +{ + using namespace mtx::crypto; + + auto txn = lmdb::txn::begin(env_); + auto db = getOlmSessionsDb(txn, curve25519); + + std::string session_id, pickled_session; + std::vector<std::string> res; + + std::optional<StoredOlmSession> currentNewest; + + auto cursor = lmdb::cursor::open(txn, db); + while (cursor.get(session_id, pickled_session, MDB_NEXT)) { + auto data = + json::parse(std::string_view(pickled_session.data(), pickled_session.size())) + .get<StoredOlmSession>(); + if (!currentNewest || currentNewest->last_message_ts < data.last_message_ts) + currentNewest = data; + } + cursor.close(); + + txn.commit(); + + return currentNewest + ? std::optional(unpickle<SessionObject>(currentNewest->pickled_session, SECRET)) + : std::nullopt; +} + std::vector<std::string> Cache::getOlmSessions(const std::string &curve25519) { @@ -828,6 +865,80 @@ Cache::runMigrations() nhlog::db()->info("Successfully deleted pending receipts database."); return true; }}, + {"2020.10.20", + [this]() { + try { + using namespace mtx::crypto; + + auto txn = lmdb::txn::begin(env_); + + auto mainDb = lmdb::dbi::open(txn, nullptr); + + std::string dbName, ignored; + auto olmDbCursor = lmdb::cursor::open(txn, mainDb); + while (olmDbCursor.get(dbName, ignored, MDB_NEXT)) { + // skip every db but olm session dbs + nhlog::db()->debug("Db {}", dbName); + if (dbName.find("olm_sessions/") != 0) + continue; + + nhlog::db()->debug("Migrating {}", dbName); + + auto olmDb = lmdb::dbi::open(txn, dbName.c_str()); + + std::string session_id, session_value; + + std::vector<std::pair<std::string, StoredOlmSession>> sessions; + + auto cursor = lmdb::cursor::open(txn, olmDb); + while (cursor.get(session_id, session_value, MDB_NEXT)) { + nhlog::db()->debug("session_id {}, session_value {}", + session_id, + session_value); + StoredOlmSession session; + bool invalid = false; + for (auto c : session_value) + if (!isprint(c)) { + invalid = true; + break; + } + if (invalid) + continue; + + nhlog::db()->debug("Not skipped"); + + session.pickled_session = session_value; + sessions.emplace_back(session_id, session); + } + cursor.close(); + + olmDb.drop(txn, true); + + auto newDbName = dbName; + newDbName.erase(0, sizeof("olm_sessions") - 1); + newDbName = "olm_sessions.v2" + newDbName; + + auto newDb = lmdb::dbi::open(txn, newDbName.c_str(), MDB_CREATE); + + for (const auto &[key, value] : sessions) { + nhlog::db()->debug("{}\n{}", key, json(value).dump()); + lmdb::dbi_put(txn, + newDb, + lmdb::val(key), + lmdb::val(json(value).dump())); + } + } + olmDbCursor.close(); + + txn.commit(); + } catch (const lmdb::error &) { + nhlog::db()->critical("Failed to migrate olm sessions,"); + return false; + } + + nhlog::db()->info("Successfully migrated olm sessions."); + return true; + }}, }; nhlog::db()->info("Running migrations, this may take a while!"); @@ -3629,6 +3740,19 @@ from_json(const nlohmann::json &obj, MegolmSessionIndex &msg) msg.sender_key = obj.at("sender_key"); } +void +to_json(nlohmann::json &obj, const StoredOlmSession &msg) +{ + obj["ts"] = msg.last_message_ts; + obj["s"] = msg.pickled_session; +} +void +from_json(const nlohmann::json &obj, StoredOlmSession &msg) +{ + msg.last_message_ts = obj.at("ts").get<uint64_t>(); + msg.pickled_session = obj.at("s").get<std::string>(); +} + namespace cache { void init(const QString &user_id) @@ -4114,9 +4238,11 @@ inboundMegolmSessionExists(const MegolmSessionIndex &index) // Olm Sessions // void -saveOlmSession(const std::string &curve25519, mtx::crypto::OlmSessionPtr session) +saveOlmSession(const std::string &curve25519, + mtx::crypto::OlmSessionPtr session, + uint64_t timestamp) { - instance_->saveOlmSession(curve25519, std::move(session)); + instance_->saveOlmSession(curve25519, std::move(session), timestamp); } std::vector<std::string> getOlmSessions(const std::string &curve25519) @@ -4128,6 +4254,11 @@ getOlmSession(const std::string &curve25519, const std::string &session_id) { return instance_->getOlmSession(curve25519, session_id); } +std::optional<mtx::crypto::OlmSessionPtr> +getLatestOlmSession(const std::string &curve25519) +{ + return instance_->getLatestOlmSession(curve25519); +} void saveOlmAccount(const std::string &pickled)