From 9102a141f3f169c39b4fe87839e646eb68fd4b55 Mon Sep 17 00:00:00 2001 From: Konstantinos Sideris Date: Fri, 15 Jun 2018 01:35:31 +0300 Subject: [PATCH] Handle OLM_MESSAGE type of messages properly --- deps/CMakeLists.txt | 2 +- include/Cache.h | 28 ++++--- include/Olm.hpp | 12 ++- src/Cache.cc | 82 ++++++++++++-------- src/Olm.cpp | 144 +++++++++++++++++++++++++---------- src/timeline/TimelineView.cc | 13 ++-- 6 files changed, 188 insertions(+), 93 deletions(-) diff --git a/deps/CMakeLists.txt b/deps/CMakeLists.txt index c948a097..7ea44bbd 100644 --- a/deps/CMakeLists.txt +++ b/deps/CMakeLists.txt @@ -40,7 +40,7 @@ set(MATRIX_STRUCTS_URL https://github.com/mujx/matrix-structs) set(MATRIX_STRUCTS_TAG eeb7373729a1618e2b3838407863342b88b8a0de) set(MTXCLIENT_URL https://github.com/mujx/mtxclient) -set(MTXCLIENT_TAG 688d5b0fd1fd16319d7fcbdbf938109eaa850545) +set(MTXCLIENT_TAG c566fa0a254dce3282435723eb58590880be2b53) set(OLM_URL https://git.matrix.org/git/olm.git) set(OLM_TAG 4065c8e11a33ba41133a086ed3de4da94dcb6bae) diff --git a/include/Cache.h b/include/Cache.h index b4dcdb90..f5a655cf 100644 --- a/include/Cache.h +++ b/include/Cache.h @@ -17,6 +17,8 @@ #pragma once +#include + #include #include @@ -209,13 +211,12 @@ struct MegolmSessionIndex struct OlmSessionStorage { - std::map outbound_sessions; + // Megolm sessions std::map group_inbound_sessions; std::map group_outbound_sessions; std::map group_outbound_session_data; - // Guards for accessing critical data. - std::mutex outbound_mtx; + // Guards for accessing megolm sessions. std::mutex group_outbound_mtx; std::mutex group_inbound_mtx; }; @@ -374,12 +375,12 @@ public: bool inboundMegolmSessionExists(const MegolmSessionIndex &index) noexcept; // - // Outbound Olm Sessions + // Olm Sessions // - void saveOutboundOlmSession(const std::string &curve25519, - mtx::crypto::OlmSessionPtr session); - OlmSession *getOutboundOlmSession(const std::string &curve25519); - bool outboundOlmSessionsExists(const std::string &curve25519) noexcept; + void saveOlmSession(const std::string &curve25519, mtx::crypto::OlmSessionPtr session); + std::vector getOlmSessions(const std::string &curve25519); + boost::optional getOlmSession(const std::string &curve25519, + const std::string &session_id); void saveOlmAccount(const std::string &pickled); std::string restoreOlmAccount(); @@ -560,6 +561,16 @@ private: return lmdb::dbi::open(txn, std::string(room_id + "/members").c_str(), MDB_CREATE); } + //! Retrieves or creates the database that stores the open OLM sessions between our device + //! and the given curve25519 key which represents another device. + //! + //! Each entry is a map from the session_id to the pickled representation of the session. + lmdb::dbi getOlmSessionsDb(lmdb::txn &txn, const std::string &curve25519_key) + { + return lmdb::dbi::open( + txn, std::string("olm_sessions/" + curve25519_key).c_str(), MDB_CREATE); + } + QString getDisplayName(const mtx::events::StateEvent &event) { if (!event.content.display_name.empty()) @@ -584,7 +595,6 @@ private: lmdb::dbi inboundMegolmSessionDb_; lmdb::dbi outboundMegolmSessionDb_; - lmdb::dbi outboundOlmSessionDb_; QString localUserId_; QString cacheDirectory_; diff --git a/include/Olm.hpp b/include/Olm.hpp index 0839f01c..6f871628 100644 --- a/include/Olm.hpp +++ b/include/Olm.hpp @@ -1,5 +1,7 @@ #pragma once +#include + #include #include #include @@ -51,13 +53,17 @@ client(); void handle_to_device_messages(const std::vector &msgs); +boost::optional +try_olm_decryption(const std::string &sender_key, const OlmCipherContent &content); + void handle_olm_message(const OlmMessage &msg); +//! Establish a new inbound megolm session with the decrypted payload from olm. void -handle_olm_normal_message(const std::string &sender, - const std::string &sender_key, - const OlmCipherContent &content); +create_inbound_megolm_session(const std::string &sender, + const std::string &sender_key, + const nlohmann::json &payload); void handle_pre_key_olm_message(const std::string &sender, diff --git a/src/Cache.cc b/src/Cache.cc index 20572ece..41a759ab 100644 --- a/src/Cache.cc +++ b/src/Cache.cc @@ -62,11 +62,10 @@ constexpr auto DEVICE_KEYS_DB("device_keys"); //! room_ids that have encryption enabled. constexpr auto ENCRYPTED_ROOMS_DB("encrypted_rooms"); -//! MegolmSessionIndex -> pickled OlmInboundGroupSession +//! room_id -> pickled OlmInboundGroupSession constexpr auto INBOUND_MEGOLM_SESSIONS_DB("inbound_megolm_sessions"); //! MegolmSessionIndex -> pickled OlmOutboundGroupSession constexpr auto OUTBOUND_MEGOLM_SESSIONS_DB("outbound_megolm_sessions"); -constexpr auto OUTBOUND_OLM_SESSIONS_DB("outbound_olm_sessions"); using CachedReceipts = std::multimap>; using Receipts = std::map>; @@ -110,7 +109,6 @@ Cache::Cache(const QString &userId, QObject *parent) , deviceKeysDb_{0} , inboundMegolmSessionDb_{0} , outboundMegolmSessionDb_{0} - , outboundOlmSessionDb_{0} , localUserId_{userId} { setup(); @@ -180,7 +178,6 @@ Cache::setup() // Session management inboundMegolmSessionDb_ = lmdb::dbi::open(txn, INBOUND_MEGOLM_SESSIONS_DB, MDB_CREATE); outboundMegolmSessionDb_ = lmdb::dbi::open(txn, OUTBOUND_MEGOLM_SESSIONS_DB, MDB_CREATE); - outboundOlmSessionDb_ = lmdb::dbi::open(txn, OUTBOUND_OLM_SESSIONS_DB, MDB_CREATE); txn.commit(); } @@ -321,35 +318,66 @@ Cache::getOutboundMegolmSession(const std::string &room_id) session_storage.group_outbound_session_data[room_id]}; } +// +// OLM sessions. +// + void -Cache::saveOutboundOlmSession(const std::string &curve25519, mtx::crypto::OlmSessionPtr session) +Cache::saveOlmSession(const std::string &curve25519, mtx::crypto::OlmSessionPtr session) { using namespace mtx::crypto; - const auto pickled = pickle(session.get(), SECRET); auto txn = lmdb::txn::begin(env_); - lmdb::dbi_put(txn, outboundOlmSessionDb_, lmdb::val(curve25519), lmdb::val(pickled)); + auto db = getOlmSessionsDb(txn, curve25519); + + const auto pickled = pickle(session.get(), SECRET); + const auto session_id = mtx::crypto::session_id(session.get()); + + lmdb::dbi_put(txn, db, lmdb::val(session_id), lmdb::val(pickled)); + + txn.commit(); +} + +boost::optional +Cache::getOlmSession(const std::string &curve25519, const std::string &session_id) +{ + using namespace mtx::crypto; + + auto txn = lmdb::txn::begin(env_); + auto db = getOlmSessionsDb(txn, curve25519); + + lmdb::val pickled; + bool found = lmdb::dbi_get(txn, db, lmdb::val(session_id), pickled); + txn.commit(); - { - std::unique_lock lock(session_storage.outbound_mtx); - session_storage.outbound_sessions[curve25519] = std::move(session); + if (found) { + auto data = std::string(pickled.data(), pickled.size()); + return unpickle(data, SECRET); } + + return boost::none; } -bool -Cache::outboundOlmSessionsExists(const std::string &curve25519) noexcept +std::vector +Cache::getOlmSessions(const std::string &curve25519) { - std::unique_lock lock(session_storage.outbound_mtx); - return session_storage.outbound_sessions.find(curve25519) != - session_storage.outbound_sessions.end(); -} + using namespace mtx::crypto; -OlmSession * -Cache::getOutboundOlmSession(const std::string &curve25519) -{ - std::unique_lock lock(session_storage.outbound_mtx); - return session_storage.outbound_sessions.at(curve25519).get(); + auto txn = lmdb::txn::begin(env_); + auto db = getOlmSessionsDb(txn, curve25519); + + std::string session_id, unused; + std::vector res; + + auto cursor = lmdb::cursor::open(txn, db); + while (cursor.get(session_id, unused, MDB_NEXT)) + res.emplace_back(session_id); + cursor.close(); + + txn.commit(); + + return res; } void @@ -405,18 +433,6 @@ Cache::restoreSessions() cursor.close(); } - // - // Outbound Olm Sessions - // - { - auto cursor = lmdb::cursor::open(txn, outboundOlmSessionDb_); - while (cursor.get(key, value, MDB_NEXT)) { - auto session = unpickle(value, SECRET); - session_storage.outbound_sessions[key] = std::move(session); - } - cursor.close(); - } - txn.commit(); nhlog::db()->info("sessions restored"); diff --git a/src/Olm.cpp b/src/Olm.cpp index f39554f0..814fce18 100644 --- a/src/Olm.cpp +++ b/src/Olm.cpp @@ -55,10 +55,21 @@ handle_olm_message(const OlmMessage &msg) const auto type = cipher.second.type; nhlog::crypto()->info("type: {}", type == 0 ? "OLM_PRE_KEY" : "OLM_MESSAGE"); - if (type == OLM_MESSAGE_TYPE_PRE_KEY) - handle_pre_key_olm_message(msg.sender, msg.sender_key, cipher.second); - else - handle_olm_normal_message(msg.sender, msg.sender_key, cipher.second); + auto payload = try_olm_decryption(msg.sender_key, cipher.second); + + if (payload) { + nhlog::crypto()->info("decrypted olm payload: {}", payload.value().dump(2)); + create_inbound_megolm_session(msg.sender, msg.sender_key, payload.value()); + return; + } + + // Not a PRE_KEY message + if (cipher.second.type != 0) { + // TODO: log that it should have matched something + return; + } + + handle_pre_key_olm_message(msg.sender, msg.sender_key, cipher.second); } } @@ -72,6 +83,10 @@ handle_pre_key_olm_message(const std::string &sender, OlmSessionPtr inbound_session = nullptr; try { inbound_session = olm::client()->create_inbound_session(content.body); + + // We also remove the one time key used to establish that + // session so we'll have to update our copy of the account object. + cache::client()->saveOlmAccount(olm::client()->save("secret")); } catch (const olm_exception &e) { nhlog::crypto()->critical( "failed to create inbound session with {}: {}", sender, e.what()); @@ -86,8 +101,8 @@ handle_pre_key_olm_message(const std::string &sender, mtx::crypto::BinaryBuf output; try { - output = olm::client()->decrypt_message( - inbound_session.get(), OLM_MESSAGE_TYPE_PRE_KEY, content.body); + output = + olm::client()->decrypt_message(inbound_session.get(), content.type, content.body); } catch (const olm_exception &e) { nhlog::crypto()->critical( "failed to decrypt olm message {}: {}", content.body, e.what()); @@ -97,45 +112,14 @@ handle_pre_key_olm_message(const std::string &sender, auto plaintext = json::parse(std::string((char *)output.data(), output.size())); nhlog::crypto()->info("decrypted message: \n {}", plaintext.dump(2)); - std::string room_id, session_id, session_key; try { - room_id = plaintext.at("content").at("room_id"); - session_id = plaintext.at("content").at("session_id"); - session_key = plaintext.at("content").at("session_key"); - } catch (const nlohmann::json::exception &e) { - nhlog::crypto()->critical( - "failed to parse plaintext olm message: {} {}", e.what(), plaintext.dump(2)); - return; + cache::client()->saveOlmSession(sender_key, std::move(inbound_session)); + } catch (const lmdb::error &e) { + nhlog::db()->warn( + "failed to save inbound olm session from {}: {}", sender, e.what()); } - MegolmSessionIndex index; - index.room_id = room_id; - index.session_id = session_id; - index.sender_key = sender_key; - - if (!cache::client()->inboundMegolmSessionExists(index)) { - auto megolm_session = olm::client()->init_inbound_group_session(session_key); - - try { - cache::client()->saveInboundMegolmSession(index, std::move(megolm_session)); - } catch (const lmdb::error &e) { - nhlog::crypto()->critical("failed to save inbound megolm session: {}", - e.what()); - return; - } - - nhlog::crypto()->info( - "established inbound megolm session ({}, {})", room_id, sender); - } else { - nhlog::crypto()->warn( - "inbound megolm session already exists ({}, {})", room_id, sender); - } -} - -void -handle_olm_normal_message(const std::string &, const std::string &, const OlmCipherContent &) -{ - nhlog::crypto()->warn("olm(1) not implemeted yet"); + create_inbound_megolm_session(sender, sender_key, plaintext); } mtx::events::msg::Encrypted @@ -165,4 +149,80 @@ encrypt_group_message(const std::string &room_id, return data; } +boost::optional +try_olm_decryption(const std::string &sender_key, const OlmCipherContent &msg) +{ + auto session_ids = cache::client()->getOlmSessions(sender_key); + + for (const auto &id : session_ids) { + auto session = cache::client()->getOlmSession(sender_key, id); + + if (!session) + continue; + + mtx::crypto::BinaryBuf text; + + try { + text = olm::client()->decrypt_message(session->get(), msg.type, msg.body); + cache::client()->saveOlmSession(id, std::move(session.value())); + + } catch (const olm_exception &e) { + nhlog::crypto()->info("failed to decrypt olm message ({}, {}) with {}: {}", + msg.type, + sender_key, + id, + e.what()); + continue; + } catch (const lmdb::error &e) { + nhlog::crypto()->critical("failed to save session: {}", e.what()); + return {}; + } + + try { + return json::parse(std::string((char *)text.data(), text.size())); + } catch (const json::exception &e) { + nhlog::crypto()->critical("failed to parse the decrypted session msg: {}", + e.what()); + } + } + + return {}; +} + +void +create_inbound_megolm_session(const std::string &sender, + const std::string &sender_key, + const nlohmann::json &payload) +{ + std::string room_id, session_id, session_key; + + try { + room_id = payload.at("content").at("room_id"); + session_id = payload.at("content").at("session_id"); + session_key = payload.at("content").at("session_key"); + } catch (const nlohmann::json::exception &e) { + nhlog::crypto()->critical( + "failed to parse plaintext olm message: {} {}", e.what(), payload.dump(2)); + return; + } + + MegolmSessionIndex index; + index.room_id = room_id; + index.session_id = session_id; + index.sender_key = sender_key; + + try { + auto megolm_session = olm::client()->init_inbound_group_session(session_key); + cache::client()->saveInboundMegolmSession(index, std::move(megolm_session)); + } catch (const lmdb::error &e) { + nhlog::crypto()->critical("failed to save inbound megolm session: {}", e.what()); + return; + } catch (const olm_exception &e) { + nhlog::crypto()->critical("failed to create inbound megolm session: {}", e.what()); + return; + } + + nhlog::crypto()->info("established inbound megolm session ({}, {})", room_id, sender); +} + } // namespace olm diff --git a/src/timeline/TimelineView.cc b/src/timeline/TimelineView.cc index 8f3ad1a7..5841ebce 100644 --- a/src/timeline/TimelineView.cc +++ b/src/timeline/TimelineView.cc @@ -1329,18 +1329,21 @@ TimelineView::prepareEncryptedMessage(const PendingMessage &msg) auto otk = rd.second.begin()->at("key"); auto id_key = pks.curve25519; - auto session = - olm::client() - ->create_outbound_session(id_key, - otk); + auto s = olm::client() + ->create_outbound_session( + id_key, otk); auto device_msg = olm::client() ->create_olm_encrypted_content( - session.get(), + s.get(), room_key, pks.curve25519); + // TODO: Handle exception + cache::client()->saveOlmSession( + id_key, std::move(s)); + json body{ {"messages", {{user_id,