diff --git a/src/Cache.cpp b/src/Cache.cpp index 738f1152..c08ecad4 100644 --- a/src/Cache.cpp +++ b/src/Cache.cpp @@ -318,26 +318,46 @@ Cache::saveInboundMegolmSession(const MegolmSessionIndex &index, auto txn = lmdb::txn::begin(env_); lmdb::dbi_put(txn, inboundMegolmSessionDb_, lmdb::val(key), lmdb::val(pickled)); txn.commit(); - - { - std::unique_lock lock(session_storage.group_inbound_mtx); - session_storage.group_inbound_sessions[key] = std::move(session); - } } -OlmInboundGroupSession * +mtx::crypto::InboundGroupSessionPtr Cache::getInboundMegolmSession(const MegolmSessionIndex &index) { - std::unique_lock lock(session_storage.group_inbound_mtx); - return session_storage.group_inbound_sessions[json(index).dump()].get(); + using namespace mtx::crypto; + + try { + auto txn = lmdb::txn::begin(env_, nullptr, MDB_RDONLY); + std::string key = json(index).dump(); + lmdb::val value; + + if (lmdb::dbi_get(txn, inboundMegolmSessionDb_, lmdb::val(key), value)) { + auto session = unpickle( + std::string(value.data(), value.size()), SECRET); + return session; + } + } catch (std::exception &e) { + nhlog::db()->error("Failed to get inbound megolm session {}", e.what()); + } + + return nullptr; } bool Cache::inboundMegolmSessionExists(const MegolmSessionIndex &index) { - std::unique_lock lock(session_storage.group_inbound_mtx); - return session_storage.group_inbound_sessions.find(json(index).dump()) != - session_storage.group_inbound_sessions.end(); + using namespace mtx::crypto; + + try { + auto txn = lmdb::txn::begin(env_, nullptr, MDB_RDONLY); + std::string key = json(index).dump(); + lmdb::val value; + + return lmdb::dbi_get(txn, inboundMegolmSessionDb_, lmdb::val(key), value); + } catch (std::exception &e) { + nhlog::db()->error("Failed to get inbound megolm session {}", e.what()); + } + + return false; } void @@ -545,18 +565,6 @@ Cache::restoreSessions() auto txn = lmdb::txn::begin(env_, nullptr, MDB_RDONLY); std::string key, value; - // - // Inbound Megolm Sessions - // - { - auto cursor = lmdb::cursor::open(txn, inboundMegolmSessionDb_); - while (cursor.get(key, value, MDB_NEXT)) { - auto session = unpickle(value, SECRET); - session_storage.group_inbound_sessions[key] = std::move(session); - } - cursor.close(); - } - // // Outbound Megolm Sessions // @@ -4173,7 +4181,7 @@ saveInboundMegolmSession(const MegolmSessionIndex &index, { instance_->saveInboundMegolmSession(index, std::move(session)); } -OlmInboundGroupSession * +mtx::crypto::InboundGroupSessionPtr getInboundMegolmSession(const MegolmSessionIndex &index) { return instance_->getInboundMegolmSession(index); diff --git a/src/Cache.h b/src/Cache.h index 8cbb0006..d2af9a1b 100644 --- a/src/Cache.h +++ b/src/Cache.h @@ -256,7 +256,7 @@ exportSessionKeys(); void saveInboundMegolmSession(const MegolmSessionIndex &index, mtx::crypto::InboundGroupSessionPtr session); -OlmInboundGroupSession * +mtx::crypto::InboundGroupSessionPtr getInboundMegolmSession(const MegolmSessionIndex &index); bool inboundMegolmSessionExists(const MegolmSessionIndex &index); diff --git a/src/CacheCryptoStructs.h b/src/CacheCryptoStructs.h index 6256dcf9..80153255 100644 --- a/src/CacheCryptoStructs.h +++ b/src/CacheCryptoStructs.h @@ -55,13 +55,11 @@ from_json(const nlohmann::json &obj, MegolmSessionIndex &msg); struct OlmSessionStorage { // Megolm sessions - std::map group_inbound_sessions; std::map group_outbound_sessions; std::map group_outbound_session_data; // Guards for accessing megolm sessions. std::mutex group_outbound_mtx; - std::mutex group_inbound_mtx; }; struct StoredOlmSession diff --git a/src/Cache_p.h b/src/Cache_p.h index 9c919fb5..f9562a65 100644 --- a/src/Cache_p.h +++ b/src/Cache_p.h @@ -246,7 +246,8 @@ public: // void saveInboundMegolmSession(const MegolmSessionIndex &index, mtx::crypto::InboundGroupSessionPtr session); - OlmInboundGroupSession *getInboundMegolmSession(const MegolmSessionIndex &index); + mtx::crypto::InboundGroupSessionPtr getInboundMegolmSession( + const MegolmSessionIndex &index); bool inboundMegolmSessionExists(const MegolmSessionIndex &index); // diff --git a/src/Olm.cpp b/src/Olm.cpp index af8bb512..0f2d583f 100644 --- a/src/Olm.cpp +++ b/src/Olm.cpp @@ -534,7 +534,7 @@ handle_key_request_message(const mtx::events::DeviceEventgetInboundMegolmSession(index); - auto res = olm::client()->decrypt_group_message(session, event.content.ciphertext); - msg_str = std::string((char *)res.data.data(), res.data.size()); + auto res = + olm::client()->decrypt_group_message(session.get(), event.content.ciphertext); + msg_str = std::string((char *)res.data.data(), res.data.size()); } catch (const lmdb::error &e) { return {DecryptionErrorCode::DbError, e.what(), std::nullopt}; } catch (const mtx::crypto::olm_exception &e) { diff --git a/src/timeline/EventStore.cpp b/src/timeline/EventStore.cpp index 1cb729d3..e561d099 100644 --- a/src/timeline/EventStore.cpp +++ b/src/timeline/EventStore.cpp @@ -604,8 +604,9 @@ EventStore::decryptEvent(const IdIndex &idx, std::string msg_str; try { auto session = cache::client()->getInboundMegolmSession(index); - auto res = olm::client()->decrypt_group_message(session, e.content.ciphertext); - msg_str = std::string((char *)res.data.data(), res.data.size()); + auto res = + olm::client()->decrypt_group_message(session.get(), e.content.ciphertext); + msg_str = std::string((char *)res.data.data(), res.data.size()); } catch (const lmdb::error &e) { nhlog::db()->critical("failed to retrieve megolm session with index ({}, {}, {})", index.room_id,