Protect against replay attacks
1 files changed, 18 insertions, 1 deletions
diff --git a/src/Olm.cpp b/src/Olm.cpp
index 293b12de..e4ab0aa1 100644
--- a/src/Olm.cpp
+++ b/src/Olm.cpp
@@ -1028,7 +1028,8 @@ send_megolm_key_to_device(const std::string &user_id,
DecryptionResult
decryptEvent(const MegolmSessionIndex &index,
- const mtx::events::EncryptedEvent<mtx::events::msg::Encrypted> &event)
+ const mtx::events::EncryptedEvent<mtx::events::msg::Encrypted> &event,
+ bool dont_write_db)
{
try {
if (!cache::client()->inboundMegolmSessionExists(index)) {
@@ -1043,10 +1044,26 @@ decryptEvent(const MegolmSessionIndex &index,
std::string msg_str;
try {
auto session = cache::client()->getInboundMegolmSession(index);
+ auto sessionData =
+ cache::client()->getMegolmSessionData(index).value_or(GroupSessionData{});
auto res =
olm::client()->decrypt_group_message(session.get(), event.content.ciphertext);
msg_str = std::string((char *)res.data.data(), res.data.size());
+
+ if (!event.event_id.empty() && event.event_id[0] == '$') {
+ auto oldIdx = sessionData.indices.find(res.message_index);
+ if (oldIdx != sessionData.indices.end()) {
+ if (oldIdx->second != event.event_id)
+ return {DecryptionErrorCode::ReplayAttack,
+ std::nullopt,
+ std::nullopt};
+ } else if (!dont_write_db) {
+ sessionData.indices[res.message_index] = event.event_id;
+ cache::client()->saveInboundMegolmSession(
+ index, std::move(session), sessionData);
+ }
+ }
} catch (const lmdb::error &e) {
return {DecryptionErrorCode::DbError, e.what(), std::nullopt};
} catch (const mtx::crypto::olm_exception &e) {
|