summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
authorNicolas Werner <nicolas.werner@hotmail.de>2021-08-07 23:54:35 +0200
committerNicolas Werner <nicolas.werner@hotmail.de>2021-08-07 23:54:35 +0200
commitb73bd2859ca9c3209f6da9c29346b95548b6b8c9 (patch)
tree1ebcab718c40c77b2ca546fca96e24739197a173 /src
parentShow encryption errors in qml and add request keys button (diff)
downloadnheko-b73bd2859ca9c3209f6da9c29346b95548b6b8c9.tar.xz
Protect against replay attacks
Diffstat (limited to '')
-rw-r--r--src/Cache.cpp6
-rw-r--r--src/CacheCryptoStructs.h3
-rw-r--r--src/Olm.cpp19
-rw-r--r--src/Olm.h4
4 files changed, 29 insertions, 3 deletions
diff --git a/src/Cache.cpp b/src/Cache.cpp
index 6650334a..ee991dc2 100644
--- a/src/Cache.cpp
+++ b/src/Cache.cpp
@@ -158,7 +158,7 @@ Cache::isHiddenEvent(lmdb::txn &txn,
                 index.session_id = encryptedEvent->content.session_id;
                 index.sender_key = encryptedEvent->content.sender_key;
 
-                auto result = olm::decryptEvent(index, *encryptedEvent);
+                auto result = olm::decryptEvent(index, *encryptedEvent, true);
                 if (!result.error)
                         e = result.event.value();
         }
@@ -4294,6 +4294,8 @@ to_json(nlohmann::json &obj, const GroupSessionData &msg)
         obj["forwarding_curve25519_key_chain"] = msg.forwarding_curve25519_key_chain;
 
         obj["currently"] = msg.currently;
+
+        obj["indices"] = msg.indices;
 }
 
 void
@@ -4307,6 +4309,8 @@ from_json(const nlohmann::json &obj, GroupSessionData &msg)
           obj.value("forwarding_curve25519_key_chain", std::vector<std::string>{});
 
         msg.currently = obj.value("currently", SharedWithUsers{});
+
+        msg.indices = obj.value("indices", std::map<uint32_t, std::string>());
 }
 
 void
diff --git a/src/CacheCryptoStructs.h b/src/CacheCryptoStructs.h
index 409c9d67..69d64885 100644
--- a/src/CacheCryptoStructs.h
+++ b/src/CacheCryptoStructs.h
@@ -50,6 +50,9 @@ struct GroupSessionData
         std::string sender_claimed_ed25519_key;
         std::vector<std::string> forwarding_curve25519_key_chain;
 
+        //! map from index to event_id to check for replay attacks
+        std::map<uint32_t, std::string> indices;
+
         // who has access to this session.
         // Rotate, when a user leaves the room and share, when a user gets added.
         SharedWithUsers currently;
diff --git a/src/Olm.cpp b/src/Olm.cpp
index 293b12de..e4ab0aa1 100644
--- a/src/Olm.cpp
+++ b/src/Olm.cpp
@@ -1028,7 +1028,8 @@ send_megolm_key_to_device(const std::string &user_id,
 
 DecryptionResult
 decryptEvent(const MegolmSessionIndex &index,
-             const mtx::events::EncryptedEvent<mtx::events::msg::Encrypted> &event)
+             const mtx::events::EncryptedEvent<mtx::events::msg::Encrypted> &event,
+             bool dont_write_db)
 {
         try {
                 if (!cache::client()->inboundMegolmSessionExists(index)) {
@@ -1043,10 +1044,26 @@ decryptEvent(const MegolmSessionIndex &index,
         std::string msg_str;
         try {
                 auto session = cache::client()->getInboundMegolmSession(index);
+                auto sessionData =
+                  cache::client()->getMegolmSessionData(index).value_or(GroupSessionData{});
 
                 auto res =
                   olm::client()->decrypt_group_message(session.get(), event.content.ciphertext);
                 msg_str = std::string((char *)res.data.data(), res.data.size());
+
+                if (!event.event_id.empty() && event.event_id[0] == '$') {
+                        auto oldIdx = sessionData.indices.find(res.message_index);
+                        if (oldIdx != sessionData.indices.end()) {
+                                if (oldIdx->second != event.event_id)
+                                        return {DecryptionErrorCode::ReplayAttack,
+                                                std::nullopt,
+                                                std::nullopt};
+                        } else if (!dont_write_db) {
+                                sessionData.indices[res.message_index] = event.event_id;
+                                cache::client()->saveInboundMegolmSession(
+                                  index, std::move(session), sessionData);
+                        }
+                }
         } catch (const lmdb::error &e) {
                 return {DecryptionErrorCode::DbError, e.what(), std::nullopt};
         } catch (const mtx::crypto::olm_exception &e) {
diff --git a/src/Olm.h b/src/Olm.h
index ac1a1617..ab86ca00 100644
--- a/src/Olm.h
+++ b/src/Olm.h
@@ -81,9 +81,11 @@ encrypt_group_message(const std::string &room_id,
                       const std::string &device_id,
                       nlohmann::json body);
 
+//! Decrypt an event. Use dont_write_db to prevent db writes when already in a write transaction.
 DecryptionResult
 decryptEvent(const MegolmSessionIndex &index,
-             const mtx::events::EncryptedEvent<mtx::events::msg::Encrypted> &event);
+             const mtx::events::EncryptedEvent<mtx::events::msg::Encrypted> &event,
+             bool dont_write_db = false);
 crypto::Trust
 calculate_trust(const std::string &user_id, const std::string &curve25519);