From d5f998c50d3188fc9eeb94ce80e4d4dfd15d6790 Mon Sep 17 00:00:00 2001 From: John Harrison Date: Thu, 7 Aug 2025 08:56:11 -0700 Subject: [PATCH 1/5] [lldb] Refactoring JSONTransport into an abstract RPC Message Handler and transport layer. This abstracts the base Transport handler to have a MessageHandler component and allows us to generalize both JSON-RPC 2.0 for MCP (or an LSP) and DAP format. This should allow us to create clearly defined clients and servers for protocols, both for testing and for RPC between the lldb instances and an lldb-mcp multiplexer. This basic model is inspiried by the clangd/Transport.h file and the mlir/lsp-server-support/Transport.h that are both used for LSP servers within the llvm project. --- lldb/include/lldb/Host/JSONTransport.h | 324 ++++++++++----- lldb/source/Host/common/JSONTransport.cpp | 116 +----- lldb/source/Protocol/MCP/Protocol.cpp | 1 + lldb/tools/lldb-dap/DAP.cpp | 177 ++++---- lldb/tools/lldb-dap/DAP.h | 25 +- lldb/tools/lldb-dap/Protocol/ProtocolBase.h | 4 + lldb/tools/lldb-dap/Transport.cpp | 5 +- lldb/tools/lldb-dap/Transport.h | 5 +- lldb/tools/lldb-dap/tool/lldb-dap.cpp | 21 +- lldb/unittests/DAP/DAPTest.cpp | 16 +- lldb/unittests/DAP/Handler/DisconnectTest.cpp | 20 +- lldb/unittests/DAP/TestBase.cpp | 48 +-- lldb/unittests/DAP/TestBase.h | 91 +++-- lldb/unittests/Host/JSONTransportTest.cpp | 382 +++++++++++------- .../ProtocolServer/ProtocolMCPServerTest.cpp | 174 ++++---- 15 files changed, 765 insertions(+), 644 deletions(-) diff --git a/lldb/include/lldb/Host/JSONTransport.h b/lldb/include/lldb/Host/JSONTransport.h index 72f4404c92887..18126f599c380 100644 --- a/lldb/include/lldb/Host/JSONTransport.h +++ b/lldb/include/lldb/Host/JSONTransport.h @@ -13,29 +13,25 @@ #ifndef LLDB_HOST_JSONTRANSPORT_H #define LLDB_HOST_JSONTRANSPORT_H +#include "lldb/Host/MainLoop.h" #include "lldb/Host/MainLoopBase.h" #include "lldb/Utility/IOObject.h" #include "lldb/Utility/Status.h" #include "lldb/lldb-forward.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Error.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/JSON.h" +#include "llvm/Support/raw_ostream.h" #include #include +#include #include namespace lldb_private { -class TransportEOFError : public llvm::ErrorInfo { -public: - static char ID; - - TransportEOFError() = default; - void log(llvm::raw_ostream &OS) const override; - std::error_code convertToErrorCode() const override; -}; - class TransportUnhandledContentsError : public llvm::ErrorInfo { public: @@ -54,112 +50,220 @@ class TransportUnhandledContentsError std::string m_unhandled_contents; }; -class TransportInvalidError : public llvm::ErrorInfo { +/// A transport is responsible for maintaining the connection to a client +/// application, and reading/writing structured messages to it. +/// +/// Transports have limited thread safety requirements: +/// - Messages will not be sent concurrently. +/// - Messages MAY be sent while Run() is reading, or its callback is active. +template class Transport { public: - static char ID; - - TransportInvalidError() = default; + using Message = std::variant; + + virtual ~Transport() = default; + + // Called by transport to send outgoing messages. + virtual void Event(const Evt &) = 0; + virtual void Request(const Req &) = 0; + virtual void Response(const Resp &) = 0; + + /// Implemented to handle incoming messages. (See Run() below). + class MessageHandler { + public: + virtual ~MessageHandler() = default; + virtual void OnEvent(const Evt &) = 0; + virtual void OnRequest(const Req &) = 0; + virtual void OnResponse(const Resp &) = 0; + }; + + /// Called by server or client to receive messages from the connection. + /// The transport should in turn invoke the handler to process messages. + /// The MainLoop is used to handle reading from the incoming connection and + /// will run until the loop is terminated. + virtual llvm::Error Run(MainLoop &, MessageHandler &) = 0; - void log(llvm::raw_ostream &OS) const override; - std::error_code convertToErrorCode() const override; +protected: + template inline auto Logv(const char *Fmt, Ts &&...Vals) { + Log(llvm::formatv(Fmt, std::forward(Vals)...).str()); + } + virtual void Log(llvm::StringRef message) = 0; }; -/// A transport class that uses JSON for communication. -class JSONTransport { +/// A JSONTransport will encode and decode messages using JSON. +template +class JSONTransport : public Transport { public: - using ReadHandleUP = MainLoopBase::ReadHandleUP; - template - using Callback = std::function)>; - - JSONTransport(lldb::IOObjectSP input, lldb::IOObjectSP output); - virtual ~JSONTransport() = default; - - /// Transport is not copyable. - /// @{ - JSONTransport(const JSONTransport &rhs) = delete; - void operator=(const JSONTransport &rhs) = delete; - /// @} - - /// Writes a message to the output stream. - template llvm::Error Write(const T &t) { - const std::string message = llvm::formatv("{0}", toJSON(t)).str(); - return WriteImpl(message); + using Transport::Transport; + + JSONTransport(lldb::IOObjectSP in, lldb::IOObjectSP out) + : m_in(in), m_out(out) {} + + void Event(const Evt &evt) override { Write(evt); } + void Request(const Req &req) override { Write(req); } + void Response(const Resp &resp) override { Write(resp); } + + /// Run registers the transport with the given MainLoop and handles any + /// incoming messages using the given MessageHandler. + llvm::Error + Run(MainLoop &loop, + typename Transport::MessageHandler &handler) override { + llvm::Error error = llvm::Error::success(); + Status status; + auto read_handle = loop.RegisterReadObject( + m_in, + std::bind(&JSONTransport::OnRead, this, &error, std::placeholders::_1, + std::ref(handler)), + status); + if (status.Fail()) { + // This error is only set if the read object handler is invoked, mark it + // as consumed if registration of the handler failed. + llvm::consumeError(std::move(error)); + return status.takeError(); + } + + status = loop.Run(); + if (status.Fail()) + return status.takeError(); + return error; } - /// Registers the transport with the MainLoop. - template - llvm::Expected RegisterReadObject(MainLoopBase &loop, - Callback read_cb) { - Status error; - ReadHandleUP handle = loop.RegisterReadObject( - m_input, - [read_cb, this](MainLoopBase &loop) { - char buf[kReadBufferSize]; - size_t num_bytes = sizeof(buf); - if (llvm::Error error = m_input->Read(buf, num_bytes).takeError()) { - read_cb(loop, std::move(error)); - return; - } - if (num_bytes) - m_buffer.append(std::string(buf, num_bytes)); - - // If the buffer has contents, try parsing any pending messages. - if (!m_buffer.empty()) { - llvm::Expected> messages = Parse(); - if (llvm::Error error = messages.takeError()) { - read_cb(loop, std::move(error)); - return; - } - - for (const auto &message : *messages) - if constexpr (std::is_same::value) - read_cb(loop, message); - else - read_cb(loop, llvm::json::parse(message)); - } - - // On EOF, notify the callback after the remaining messages were - // handled. - if (num_bytes == 0) { - if (m_buffer.empty()) - read_cb(loop, llvm::make_error()); - else - read_cb(loop, llvm::make_error( - std::string(m_buffer))); - } - }, - error); - if (error.Fail()) - return error.takeError(); - return handle; - } + /// Public for testing purposes, otherwise this should be an implementation + /// detail. + static constexpr size_t kReadBufferSize = 1024; protected: - template inline auto Logv(const char *Fmt, Ts &&...Vals) { - Log(llvm::formatv(Fmt, std::forward(Vals)...).str()); + virtual llvm::Expected> Parse() = 0; + virtual std::string Encode(const llvm::json::Value &message) = 0; + void Write(const llvm::json::Value &message) { + this->Logv("<-- {0}", message); + std::string output = Encode(message); + size_t bytes_written = output.size(); + Status status = m_out->Write(output.data(), bytes_written); + if (status.Fail()) { + this->Logv("writing failed: s{0}", status.AsCString()); + } } - virtual void Log(llvm::StringRef message); - virtual llvm::Error WriteImpl(const std::string &message) = 0; - virtual llvm::Expected> Parse() = 0; + llvm::SmallString m_buffer; - static constexpr size_t kReadBufferSize = 1024; +private: + void OnRead(llvm::Error *err, MainLoopBase &loop, + typename Transport::MessageHandler &handler) { + llvm::ErrorAsOutParameter ErrAsOutParam(err); + char buf[kReadBufferSize]; + size_t num_bytes = sizeof(buf); + if (Status status = m_in->Read(buf, num_bytes); status.Fail()) { + *err = status.takeError(); + loop.RequestTermination(); + return; + } + + if (num_bytes) + m_buffer.append(llvm::StringRef(buf, num_bytes)); + + // If the buffer has contents, try parsing any pending messages. + if (!m_buffer.empty()) { + llvm::Expected> raw_messages = Parse(); + if (llvm::Error error = raw_messages.takeError()) { + *err = std::move(error); + loop.RequestTermination(); + return; + } + + for (const auto &raw_message : *raw_messages) { + auto message = + llvm::json::parse::Message>( + raw_message); + if (!message) { + *err = message.takeError(); + loop.RequestTermination(); + return; + } + + if (Evt *evt = std::get_if(&*message)) { + handler.OnEvent(*evt); + } else if (Req *req = std::get_if(&*message)) { + handler.OnRequest(*req); + } else if (Resp *resp = std::get_if(&*message)) { + handler.OnResponse(*resp); + } else { + llvm_unreachable("unknown message type"); + } + } + } + + if (num_bytes == 0) { + // If we're at EOF and we have unhandled contents in the buffer, return an + // error for the partial message. + if (m_buffer.empty()) + *err = llvm::Error::success(); + else + *err = llvm::make_error( + std::string(m_buffer)); + loop.RequestTermination(); + } + } - lldb::IOObjectSP m_input; - lldb::IOObjectSP m_output; - llvm::SmallString m_buffer; + lldb::IOObjectSP m_in; + lldb::IOObjectSP m_out; }; /// A transport class for JSON with a HTTP header. -class HTTPDelimitedJSONTransport : public JSONTransport { +template +class HTTPDelimitedJSONTransport : public JSONTransport { public: - HTTPDelimitedJSONTransport(lldb::IOObjectSP input, lldb::IOObjectSP output) - : JSONTransport(input, output) {} - virtual ~HTTPDelimitedJSONTransport() = default; + using JSONTransport::JSONTransport; protected: - llvm::Error WriteImpl(const std::string &message) override; - llvm::Expected> Parse() override; + /// Encodes messages based on + /// https://microsoft.github.io/debug-adapter-protocol/overview#base-protocol + std::string Encode(const llvm::json::Value &message) override { + std::string output; + std::string raw_message = llvm::formatv("{0}", message).str(); + llvm::raw_string_ostream OS(output); + OS << kHeaderContentLength << kHeaderFieldSeparator << ' ' + << std::to_string(raw_message.size()) << kEndOfHeader << raw_message; + return output; + } + + /// Parses messages based on + /// https://microsoft.github.io/debug-adapter-protocol/overview#base-protocol + llvm::Expected> Parse() override { + std::vector messages; + llvm::StringRef buffer = this->m_buffer; + while (buffer.contains(kEndOfHeader)) { + auto [headers, rest] = buffer.split(kEndOfHeader); + size_t content_length = 0; + // HTTP Headers are formatted like ` ':' []`. + for (const auto &header : llvm::split(headers, kHeaderSeparator)) { + auto [key, value] = header.split(kHeaderFieldSeparator); + // 'Content-Length' is the only meaningful key at the moment. Others are + // ignored. + if (!key.equals_insensitive(kHeaderContentLength)) + continue; + + value = value.trim(); + if (!llvm::to_integer(value, content_length, 10)) + return llvm::createStringError(std::errc::invalid_argument, + "invalid content length: %s", + value.str().c_str()); + } + + // Check if we have enough data. + if (content_length > rest.size()) + break; + + llvm::StringRef body = rest.take_front(content_length); + buffer = rest.drop_front(content_length); + messages.emplace_back(body.str()); + this->Logv("--> {0}", body); + } + + // Store the remainder of the buffer for the next read callback. + this->m_buffer = buffer.str(); + + return std::move(messages); + } static constexpr llvm::StringLiteral kHeaderContentLength = "Content-Length"; static constexpr llvm::StringLiteral kHeaderFieldSeparator = ":"; @@ -168,15 +272,31 @@ class HTTPDelimitedJSONTransport : public JSONTransport { }; /// A transport class for JSON RPC. -class JSONRPCTransport : public JSONTransport { +template +class JSONRPCTransport : public JSONTransport { public: - JSONRPCTransport(lldb::IOObjectSP input, lldb::IOObjectSP output) - : JSONTransport(input, output) {} - virtual ~JSONRPCTransport() = default; + using JSONTransport::JSONTransport; protected: - llvm::Error WriteImpl(const std::string &message) override; - llvm::Expected> Parse() override; + std::string Encode(const llvm::json::Value &message) override { + return llvm::formatv("{0}{1}", message, kMessageSeparator).str(); + } + + llvm::Expected> Parse() override { + std::vector messages; + llvm::StringRef buf = this->m_buffer; + while (buf.contains(kMessageSeparator)) { + auto [raw_json, rest] = buf.split(kMessageSeparator); + buf = rest; + messages.emplace_back(raw_json.str()); + this->Logv("--> {0}", raw_json); + } + + // Store the remainder of the buffer for the next read callback. + this->m_buffer = buf.str(); + + return messages; + } static constexpr llvm::StringLiteral kMessageSeparator = "\n"; }; diff --git a/lldb/source/Host/common/JSONTransport.cpp b/lldb/source/Host/common/JSONTransport.cpp index 5f0fb3ce562c3..c4b42eafc85d3 100644 --- a/lldb/source/Host/common/JSONTransport.cpp +++ b/lldb/source/Host/common/JSONTransport.cpp @@ -7,136 +7,26 @@ //===----------------------------------------------------------------------===// #include "lldb/Host/JSONTransport.h" -#include "lldb/Utility/LLDBLog.h" #include "lldb/Utility/Log.h" #include "lldb/Utility/Status.h" -#include "lldb/lldb-forward.h" #include "llvm/ADT/StringExtras.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/Error.h" #include "llvm/Support/raw_ostream.h" #include -#include using namespace llvm; using namespace lldb; using namespace lldb_private; -void TransportEOFError::log(llvm::raw_ostream &OS) const { - OS << "transport EOF"; -} - -std::error_code TransportEOFError::convertToErrorCode() const { - return std::make_error_code(std::errc::io_error); -} +char TransportUnhandledContentsError::ID; TransportUnhandledContentsError::TransportUnhandledContentsError( std::string unhandled_contents) : m_unhandled_contents(unhandled_contents) {} void TransportUnhandledContentsError::log(llvm::raw_ostream &OS) const { - OS << "transport EOF with unhandled contents " << m_unhandled_contents; + OS << "transport EOF with unhandled contents: '" << m_unhandled_contents + << "'"; } std::error_code TransportUnhandledContentsError::convertToErrorCode() const { return std::make_error_code(std::errc::bad_message); } - -void TransportInvalidError::log(llvm::raw_ostream &OS) const { - OS << "transport IO object invalid"; -} -std::error_code TransportInvalidError::convertToErrorCode() const { - return std::make_error_code(std::errc::not_connected); -} - -JSONTransport::JSONTransport(IOObjectSP input, IOObjectSP output) - : m_input(std::move(input)), m_output(std::move(output)) {} - -void JSONTransport::Log(llvm::StringRef message) { - LLDB_LOG(GetLog(LLDBLog::Host), "{0}", message); -} - -// Parses messages based on -// https://microsoft.github.io/debug-adapter-protocol/overview#base-protocol -Expected> HTTPDelimitedJSONTransport::Parse() { - std::vector messages; - StringRef buffer = m_buffer; - while (buffer.contains(kEndOfHeader)) { - auto [headers, rest] = buffer.split(kEndOfHeader); - size_t content_length = 0; - // HTTP Headers are formatted like ` ':' []`. - for (const auto &header : llvm::split(headers, kHeaderSeparator)) { - auto [key, value] = header.split(kHeaderFieldSeparator); - // 'Content-Length' is the only meaningful key at the moment. Others are - // ignored. - if (!key.equals_insensitive(kHeaderContentLength)) - continue; - - value = value.trim(); - if (!llvm::to_integer(value, content_length, 10)) - return createStringError(std::errc::invalid_argument, - "invalid content length: %s", - value.str().c_str()); - } - - // Check if we have enough data. - if (content_length > rest.size()) - break; - - StringRef body = rest.take_front(content_length); - buffer = rest.drop_front(content_length); - messages.emplace_back(body.str()); - Logv("--> {0}", body); - } - - // Store the remainder of the buffer for the next read callback. - m_buffer = buffer.str(); - - return std::move(messages); -} - -Error HTTPDelimitedJSONTransport::WriteImpl(const std::string &message) { - if (!m_output || !m_output->IsValid()) - return llvm::make_error(); - - Logv("<-- {0}", message); - - std::string Output; - raw_string_ostream OS(Output); - OS << kHeaderContentLength << kHeaderFieldSeparator << ' ' << message.length() - << kHeaderSeparator << kHeaderSeparator << message; - size_t num_bytes = Output.size(); - return m_output->Write(Output.data(), num_bytes).takeError(); -} - -Expected> JSONRPCTransport::Parse() { - std::vector messages; - StringRef buf = m_buffer; - while (buf.contains(kMessageSeparator)) { - auto [raw_json, rest] = buf.split(kMessageSeparator); - buf = rest; - messages.emplace_back(raw_json.str()); - Logv("--> {0}", raw_json); - } - - // Store the remainder of the buffer for the next read callback. - m_buffer = buf.str(); - - return messages; -} - -Error JSONRPCTransport::WriteImpl(const std::string &message) { - if (!m_output || !m_output->IsValid()) - return llvm::make_error(); - - Logv("<-- {0}", message); - - std::string Output; - llvm::raw_string_ostream OS(Output); - OS << message << kMessageSeparator; - size_t num_bytes = Output.size(); - return m_output->Write(Output.data(), num_bytes).takeError(); -} - -char TransportEOFError::ID; -char TransportUnhandledContentsError::ID; -char TransportInvalidError::ID; diff --git a/lldb/source/Protocol/MCP/Protocol.cpp b/lldb/source/Protocol/MCP/Protocol.cpp index d9b11bd766686..65ddfaee70160 100644 --- a/lldb/source/Protocol/MCP/Protocol.cpp +++ b/lldb/source/Protocol/MCP/Protocol.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "lldb/Protocol/MCP/Protocol.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/JSON.h" using namespace llvm; diff --git a/lldb/tools/lldb-dap/DAP.cpp b/lldb/tools/lldb-dap/DAP.cpp index ce910b1f60b85..a9a0fe75a35b7 100644 --- a/lldb/tools/lldb-dap/DAP.cpp +++ b/lldb/tools/lldb-dap/DAP.cpp @@ -121,11 +121,12 @@ static std::string capitalize(llvm::StringRef str) { llvm::StringRef DAP::debug_adapter_path = ""; DAP::DAP(Log *log, const ReplMode default_repl_mode, - std::vector pre_init_commands, Transport &transport) + std::vector pre_init_commands, + llvm::StringRef client_name, DAPTransport &transport, MainLoop &loop) : log(log), transport(transport), broadcaster("lldb-dap"), progress_event_reporter( [&](const ProgressEvent &event) { SendJSON(event.ToJSON()); }), - repl_mode(default_repl_mode) { + repl_mode(default_repl_mode), m_client_name(client_name), m_loop(loop) { configuration.preInitCommands = std::move(pre_init_commands); RegisterRequests(); } @@ -258,36 +259,33 @@ void DAP::SendJSON(const llvm::json::Value &json) { llvm::json::Path::Root root; if (!fromJSON(json, message, root)) { DAP_LOG_ERROR(log, root.getError(), "({1}) encoding failed: {0}", - transport.GetClientName()); + m_client_name); return; } Send(message); } void DAP::Send(const Message &message) { - // FIXME: After all the requests have migrated from LegacyRequestHandler > - // RequestHandler<> this should be handled in RequestHandler<>::operator(). - if (auto *resp = std::get_if(&message); - resp && debugger.InterruptRequested()) { - // Clear the interrupt request. - debugger.CancelInterruptRequest(); - - // If the debugger was interrupted, convert this response into a 'cancelled' - // response because we might have a partial result. - Response cancelled{/*request_seq=*/resp->request_seq, - /*command=*/resp->command, - /*success=*/false, - /*message=*/eResponseMessageCancelled, - /*body=*/std::nullopt}; - if (llvm::Error err = transport.Write(cancelled)) - DAP_LOG_ERROR(log, std::move(err), "({1}) write failed: {0}", - transport.GetClientName()); - return; + if (const protocol::Event *event = std::get_if(&message)) { + transport.Event(*event); + } else if (const Request *req = std::get_if(&message)) { + transport.Request(*req); + } else if (const Response *resp = std::get_if(&message)) { + // FIXME: After all the requests have migrated from LegacyRequestHandler > + // RequestHandler<> this should be handled in RequestHandler<>::operator(). + if (debugger.InterruptRequested()) + // If the debugger was interrupted, convert this response into a + // 'cancelled' response because we might have a partial result. + transport.Response(Response{/*request_seq=*/resp->request_seq, + /*command=*/resp->command, + /*success=*/false, + /*message=*/eResponseMessageCancelled, + /*body=*/std::nullopt}); + else + transport.Response(*resp); + } else { + llvm_unreachable("Unexpected message type"); } - - if (llvm::Error err = transport.Write(message)) - DAP_LOG_ERROR(log, std::move(err), "({1}) write failed: {0}", - transport.GetClientName()); } // "OutputEvent": { @@ -755,7 +753,6 @@ void DAP::RunTerminateCommands() { } lldb::SBTarget DAP::CreateTarget(lldb::SBError &error) { - // Grab the name of the program we need to debug and create a target using // the given program as an argument. Executable file can be a source of target // architecture and platform, if they differ from the host. Setting exe path // in launch info is useless because Target.Launch() will not change @@ -795,7 +792,7 @@ void DAP::SetTarget(const lldb::SBTarget target) { bool DAP::HandleObject(const Message &M) { TelemetryDispatcher dispatcher(&debugger); - dispatcher.Set("client_name", transport.GetClientName().str()); + dispatcher.Set("client_name", m_client_name.str()); if (const auto *req = std::get_if(&M)) { { std::lock_guard guard(m_active_request_mutex); @@ -821,8 +818,8 @@ bool DAP::HandleObject(const Message &M) { dispatcher.Set("error", llvm::Twine("unhandled-command:" + req->command).str()); - DAP_LOG(log, "({0}) error: unhandled command '{1}'", - transport.GetClientName(), req->command); + DAP_LOG(log, "({0}) error: unhandled command '{1}'", m_client_name, + req->command); return false; // Fail } @@ -920,8 +917,6 @@ llvm::Error DAP::Disconnect(bool terminateDebuggee) { SendTerminatedEvent(); disconnecting = true; - m_loop.AddPendingCallback( - [](MainLoopBase &loop) { loop.RequestTermination(); }); return ToError(error); } @@ -938,88 +933,74 @@ void DAP::ClearCancelRequest(const CancelArguments &args) { } template -static std::optional getArgumentsIfRequest(const Message &pm, +static std::optional getArgumentsIfRequest(const Request &req, llvm::StringLiteral command) { - auto *const req = std::get_if(&pm); - if (!req || req->command != command) + if (req.command != command) return std::nullopt; T args; llvm::json::Path::Root root; - if (!fromJSON(req->arguments, args, root)) + if (!fromJSON(req.arguments, args, root)) return std::nullopt; return args; } -Status DAP::TransportHandler() { - llvm::set_thread_name(transport.GetClientName() + ".transport_handler"); +void DAP::OnEvent(const protocol::Event &event) { + // no-op, no supported events from the client to the server as of DAP v1.68. +} - auto cleanup = llvm::make_scope_exit([&]() { - // Ensure we're marked as disconnecting when the reader exits. +void DAP::OnRequest(const protocol::Request &request) { + if (request.command == "disconnect") disconnecting = true; - m_queue_cv.notify_all(); - }); - Status status; - auto handle = transport.RegisterReadObject( - m_loop, - [&](MainLoopBase &loop, llvm::Expected message) { - if (message.errorIsA()) { - llvm::consumeError(message.takeError()); - loop.RequestTermination(); - return; - } + const std::optional cancel_args = + getArgumentsIfRequest(request, "cancel"); + if (cancel_args) { + { + std::lock_guard guard(m_cancelled_requests_mutex); + if (cancel_args->requestId) + m_cancelled_requests.insert(*cancel_args->requestId); + } - if (llvm::Error err = message.takeError()) { - status = Status::FromError(std::move(err)); - loop.RequestTermination(); - return; - } + // If a cancel is requested for the active request, make a best + // effort attempt to interrupt. + std::lock_guard guard(m_active_request_mutex); + if (m_active_request && cancel_args->requestId == m_active_request->seq) { + DAP_LOG(log, "({0}) interrupting inflight request (command={1} seq={2})", + m_client_name, m_active_request->command, m_active_request->seq); + debugger.RequestInterrupt(); + } + } - if (const protocol::Request *req = - std::get_if(&*message); - req && req->arguments == "disconnect") - disconnecting = true; - - const std::optional cancel_args = - getArgumentsIfRequest(*message, "cancel"); - if (cancel_args) { - { - std::lock_guard guard(m_cancelled_requests_mutex); - if (cancel_args->requestId) - m_cancelled_requests.insert(*cancel_args->requestId); - } + std::lock_guard guard(m_queue_mutex); + DAP_LOG(log, "({0}) queued (command={1} seq={2})", m_client_name, + request.command, request.seq); + m_queue.push_back(request); + m_queue_cv.notify_one(); +} - // If a cancel is requested for the active request, make a best - // effort attempt to interrupt. - std::lock_guard guard(m_active_request_mutex); - if (m_active_request && - cancel_args->requestId == m_active_request->seq) { - DAP_LOG(log, - "({0}) interrupting inflight request (command={1} seq={2})", - transport.GetClientName(), m_active_request->command, - m_active_request->seq); - debugger.RequestInterrupt(); - } - } +void DAP::OnResponse(const protocol::Response &response) { + std::lock_guard guard(m_queue_mutex); + DAP_LOG(log, "({0}) queued (command={1} seq={2})", m_client_name, + response.command, response.request_seq); + m_queue.push_back(response); + m_queue_cv.notify_one(); +} - std::lock_guard guard(m_queue_mutex); - m_queue.push_back(std::move(*message)); - m_queue_cv.notify_one(); - }); - if (auto err = handle.takeError()) - return Status::FromError(std::move(err)); - if (llvm::Error err = m_loop.Run().takeError()) - return Status::FromError(std::move(err)); - return status; +void DAP::TransportHandler(llvm::Error *error) { + llvm::ErrorAsOutParameter ErrAsOutParam(*error); + auto cleanup = llvm::make_scope_exit([&]() { + // Ensure we're marked as disconnecting when the reader exits. + disconnecting = true; + m_queue_cv.notify_all(); + }); + *error = transport.Run(m_loop, *this); } llvm::Error DAP::Loop() { - // Can't use \a std::future because it doesn't compile on - // Windows. - std::future queue_reader = - std::async(std::launch::async, &DAP::TransportHandler, this); + llvm::Error error = llvm::Error::success(); + auto thread = std::thread(std::bind(&DAP::TransportHandler, this, &error)); auto cleanup = llvm::make_scope_exit([&]() { out.Stop(); @@ -1045,7 +1026,11 @@ llvm::Error DAP::Loop() { "unhandled packet"); } - return queue_reader.get().takeError(); + m_loop.AddPendingCallback( + [](MainLoopBase &loop) { loop.RequestTermination(); }); + thread.join(); + + return error; } lldb::SBError DAP::WaitForProcessToStop(std::chrono::seconds seconds) { @@ -1284,7 +1269,7 @@ void DAP::ProgressEventThread() { // them prevent multiple threads from writing simultaneously so no locking // is required. void DAP::EventThread() { - llvm::set_thread_name(transport.GetClientName() + ".event_handler"); + llvm::set_thread_name("lldb.DAP.client." + m_client_name + ".event_handler"); lldb::SBEvent event; lldb::SBListener listener = debugger.GetListener(); broadcaster.AddListener(listener, eBroadcastBitStopEventThread); @@ -1316,7 +1301,7 @@ void DAP::EventThread() { if (llvm::Error err = SendThreadStoppedEvent(*this)) DAP_LOG_ERROR(log, std::move(err), "({1}) reporting thread stopped: {0}", - transport.GetClientName()); + m_client_name); } break; case lldb::eStateRunning: diff --git a/lldb/tools/lldb-dap/DAP.h b/lldb/tools/lldb-dap/DAP.h index b0e9fa9c16b75..628f97257d5f0 100644 --- a/lldb/tools/lldb-dap/DAP.h +++ b/lldb/tools/lldb-dap/DAP.h @@ -78,12 +78,16 @@ enum DAPBroadcasterBits { enum class ReplMode { Variable = 0, Command, Auto }; -struct DAP { +using DAPTransport = + lldb_private::Transport; + +struct DAP final : private DAPTransport::MessageHandler { /// Path to the lldb-dap binary itself. static llvm::StringRef debug_adapter_path; Log *log; - Transport &transport; + DAPTransport &transport; lldb::SBFile in; OutputRedirector out; OutputRedirector err; @@ -177,8 +181,11 @@ struct DAP { /// allocated. /// \param[in] transport /// Transport for this debug session. + /// \param[in] loop + /// Main loop associated with this instance. DAP(Log *log, const ReplMode default_repl_mode, - std::vector pre_init_commands, Transport &transport); + std::vector pre_init_commands, llvm::StringRef client_name, + DAPTransport &transport, lldb_private::MainLoop &loop); ~DAP(); @@ -317,7 +324,7 @@ struct DAP { lldb::SBTarget CreateTarget(lldb::SBError &error); /// Set given target object as a current target for lldb-dap and start - /// listeing for its breakpoint events. + /// listening for its breakpoint events. void SetTarget(const lldb::SBTarget target); bool HandleObject(const protocol::Message &M); @@ -420,13 +427,17 @@ struct DAP { const std::optional> &breakpoints); + void OnEvent(const protocol::Event &) override; + void OnRequest(const protocol::Request &) override; + void OnResponse(const protocol::Response &) override; + private: std::vector SetSourceBreakpoints( const protocol::Source &source, const std::optional> &breakpoints, SourceBreakpointMap &existing_breakpoints); - lldb_private::Status TransportHandler(); + void TransportHandler(llvm::Error *); /// Registration of request handler. /// @{ @@ -446,6 +457,8 @@ struct DAP { std::thread progress_event_thread; /// @} + const llvm::StringRef m_client_name; + /// List of addresses mapped by sourceReference. std::vector m_source_references; std::mutex m_source_references_mutex; @@ -456,7 +469,7 @@ struct DAP { std::condition_variable m_queue_cv; // Loop for managing reading from the client. - lldb_private::MainLoop m_loop; + lldb_private::MainLoop &m_loop; std::mutex m_cancelled_requests_mutex; llvm::SmallSet m_cancelled_requests; diff --git a/lldb/tools/lldb-dap/Protocol/ProtocolBase.h b/lldb/tools/lldb-dap/Protocol/ProtocolBase.h index 81496380d412f..0a9ef538a7398 100644 --- a/lldb/tools/lldb-dap/Protocol/ProtocolBase.h +++ b/lldb/tools/lldb-dap/Protocol/ProtocolBase.h @@ -52,6 +52,7 @@ struct Request { }; llvm::json::Value toJSON(const Request &); bool fromJSON(const llvm::json::Value &, Request &, llvm::json::Path); +bool operator==(const Request &, const Request &); /// A debug adapter initiated event. struct Event { @@ -63,6 +64,7 @@ struct Event { }; llvm::json::Value toJSON(const Event &); bool fromJSON(const llvm::json::Value &, Event &, llvm::json::Path); +bool operator==(const Event &, const Event &); enum ResponseMessage : unsigned { /// The request was cancelled @@ -101,6 +103,7 @@ struct Response { }; bool fromJSON(const llvm::json::Value &, Response &, llvm::json::Path); llvm::json::Value toJSON(const Response &); +bool operator==(const Response &, const Response &); /// A structured message object. Used to return errors from requests. struct ErrorMessage { @@ -140,6 +143,7 @@ llvm::json::Value toJSON(const ErrorMessage &); using Message = std::variant; bool fromJSON(const llvm::json::Value &, Message &, llvm::json::Path); llvm::json::Value toJSON(const Message &); +bool operator==(const Message &, const Message &); inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const Message &V) { OS << toJSON(V); diff --git a/lldb/tools/lldb-dap/Transport.cpp b/lldb/tools/lldb-dap/Transport.cpp index d602920da34e3..8f71f88cae1f7 100644 --- a/lldb/tools/lldb-dap/Transport.cpp +++ b/lldb/tools/lldb-dap/Transport.cpp @@ -14,7 +14,8 @@ using namespace llvm; using namespace lldb; using namespace lldb_private; -using namespace lldb_dap; + +namespace lldb_dap { Transport::Transport(llvm::StringRef client_name, lldb_dap::Log *log, lldb::IOObjectSP input, lldb::IOObjectSP output) @@ -24,3 +25,5 @@ Transport::Transport(llvm::StringRef client_name, lldb_dap::Log *log, void Transport::Log(llvm::StringRef message) { DAP_LOG(m_log, "({0}) {1}", m_client_name, message); } + +} // namespace lldb_dap diff --git a/lldb/tools/lldb-dap/Transport.h b/lldb/tools/lldb-dap/Transport.h index 9a7d8f424d40e..efeb0b9cd6c55 100644 --- a/lldb/tools/lldb-dap/Transport.h +++ b/lldb/tools/lldb-dap/Transport.h @@ -15,6 +15,7 @@ #define LLDB_TOOLS_LLDB_DAP_TRANSPORT_H #include "DAPForward.h" +#include "Protocol/ProtocolBase.h" #include "lldb/Host/JSONTransport.h" #include "lldb/lldb-forward.h" #include "llvm/ADT/StringRef.h" @@ -23,7 +24,9 @@ namespace lldb_dap { /// A transport class that performs the Debug Adapter Protocol communication /// with the client. -class Transport : public lldb_private::HTTPDelimitedJSONTransport { +class Transport final + : public lldb_private::HTTPDelimitedJSONTransport< + protocol::Request, protocol::Response, protocol::Event> { public: Transport(llvm::StringRef client_name, lldb_dap::Log *log, lldb::IOObjectSP input, lldb::IOObjectSP output); diff --git a/lldb/tools/lldb-dap/tool/lldb-dap.cpp b/lldb/tools/lldb-dap/tool/lldb-dap.cpp index 8bba4162aa7bf..c728b0af94c7c 100644 --- a/lldb/tools/lldb-dap/tool/lldb-dap.cpp +++ b/lldb/tools/lldb-dap/tool/lldb-dap.cpp @@ -284,7 +284,7 @@ serveConnection(const Socket::SocketProtocol &protocol, const std::string &name, }); std::condition_variable dap_sessions_condition; std::mutex dap_sessions_mutex; - std::map dap_sessions; + std::map dap_sessions; unsigned int clientCount = 0; auto handle = listener->Accept(g_loop, [=, &dap_sessions_condition, &dap_sessions_mutex, &dap_sessions, @@ -300,8 +300,10 @@ serveConnection(const Socket::SocketProtocol &protocol, const std::string &name, std::thread client([=, &dap_sessions_condition, &dap_sessions_mutex, &dap_sessions]() { llvm::set_thread_name(client_name + ".runloop"); + MainLoop loop; Transport transport(client_name, log, io, io); - DAP dap(log, default_repl_mode, pre_init_commands, transport); + DAP dap(log, default_repl_mode, pre_init_commands, client_name, transport, + loop); if (auto Err = dap.ConfigureIO()) { llvm::logAllUnhandledErrors(std::move(Err), llvm::errs(), @@ -311,7 +313,7 @@ serveConnection(const Socket::SocketProtocol &protocol, const std::string &name, { std::scoped_lock lock(dap_sessions_mutex); - dap_sessions[io.get()] = &dap; + dap_sessions[&loop] = &dap; } if (auto Err = dap.Loop()) { @@ -322,7 +324,7 @@ serveConnection(const Socket::SocketProtocol &protocol, const std::string &name, DAP_LOG(log, "({0}) client disconnected", client_name); std::unique_lock lock(dap_sessions_mutex); - dap_sessions.erase(io.get()); + dap_sessions.erase(&loop); std::notify_all_at_thread_exit(dap_sessions_condition, std::move(lock)); }); client.detach(); @@ -344,13 +346,14 @@ serveConnection(const Socket::SocketProtocol &protocol, const std::string &name, bool client_failed = false; { std::scoped_lock lock(dap_sessions_mutex); - for (auto [sock, dap] : dap_sessions) { + for (auto [loop, dap] : dap_sessions) { if (llvm::Error error = dap->Disconnect()) { client_failed = true; - llvm::errs() << "DAP client " << dap->transport.GetClientName() - << " disconnected failed: " + llvm::errs() << "DAP client disconnected failed: " << llvm::toString(std::move(error)) << "\n"; } + loop->AddPendingCallback( + [](MainLoopBase &loop) { loop.RequestTermination(); }); } } @@ -550,8 +553,10 @@ int main(int argc, char *argv[]) { stdout_fd, File::eOpenOptionWriteOnly, NativeFile::Unowned); constexpr llvm::StringLiteral client_name = "stdio"; + MainLoop loop; Transport transport(client_name, log.get(), input, output); - DAP dap(log.get(), default_repl_mode, pre_init_commands, transport); + DAP dap(log.get(), default_repl_mode, pre_init_commands, client_name, + transport, loop); // stdout/stderr redirection to the IDE's console if (auto Err = dap.ConfigureIO(stdout, stderr)) { diff --git a/lldb/unittests/DAP/DAPTest.cpp b/lldb/unittests/DAP/DAPTest.cpp index 138910d917424..744e6e69a8d33 100644 --- a/lldb/unittests/DAP/DAPTest.cpp +++ b/lldb/unittests/DAP/DAPTest.cpp @@ -11,7 +11,6 @@ #include "TestBase.h" #include "llvm/Testing/Support/Error.h" #include "gtest/gtest.h" -#include #include using namespace llvm; @@ -27,12 +26,15 @@ TEST_F(DAPTest, SendProtocolMessages) { /*log=*/nullptr, /*default_repl_mode=*/ReplMode::Auto, /*pre_init_commands=*/{}, - /*transport=*/*to_dap, + /*client_name=*/"test_client", + /*transport=*/*transport, + loop, }; dap.Send(Event{/*event=*/"my-event", /*body=*/std::nullopt}); - RunOnce([&](llvm::Expected message) { - ASSERT_THAT_EXPECTED( - message, HasValue(testing::VariantWith(testing::FieldsAre( - /*event=*/"my-event", /*body=*/std::nullopt)))); - }); + loop.AddPendingCallback( + [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); }); + ASSERT_THAT_ERROR(dap.Loop(), llvm::Succeeded()); + ASSERT_THAT(from_dap, + ElementsAre(testing::VariantWith(testing::FieldsAre( + /*event=*/"my-event", /*body=*/std::nullopt)))); } diff --git a/lldb/unittests/DAP/Handler/DisconnectTest.cpp b/lldb/unittests/DAP/Handler/DisconnectTest.cpp index 0546aeb154d50..5b082151680dd 100644 --- a/lldb/unittests/DAP/Handler/DisconnectTest.cpp +++ b/lldb/unittests/DAP/Handler/DisconnectTest.cpp @@ -31,8 +31,8 @@ TEST_F(DisconnectRequestHandlerTest, DisconnectTriggersTerminated) { EXPECT_FALSE(dap->disconnecting); ASSERT_THAT_ERROR(handler.Run(std::nullopt), Succeeded()); EXPECT_TRUE(dap->disconnecting); - std::vector messages = DrainOutput(); - EXPECT_THAT(messages, + RunOnce(); + EXPECT_THAT(from_dap, testing::Contains(testing::VariantWith(testing::FieldsAre( /*event=*/"terminated", /*body=*/testing::_)))); } @@ -53,11 +53,13 @@ TEST_F(DisconnectRequestHandlerTest, DisconnectTriggersTerminateCommands) { EXPECT_EQ(dap->target.GetProcess().GetState(), lldb::eStateStopped); ASSERT_THAT_ERROR(handler.Run(std::nullopt), Succeeded()); EXPECT_TRUE(dap->disconnecting); - std::vector messages = DrainOutput(); - EXPECT_THAT(messages, testing::ElementsAre( - OutputMatcher("Running terminateCommands:\n"), - OutputMatcher("(lldb) script print(2)\n"), - OutputMatcher("2\n"), - testing::VariantWith(testing::FieldsAre( - /*event=*/"terminated", /*body=*/testing::_)))); + RunOnce(); + EXPECT_THAT(from_dap, + testing::Contains(OutputMatcher("Running terminateCommands:\n"))); + EXPECT_THAT(from_dap, + testing::Contains(OutputMatcher("(lldb) script print(2)\n"))); + EXPECT_THAT(from_dap, testing::Contains(OutputMatcher("2\n"))); + EXPECT_THAT(from_dap, + testing::Contains(testing::VariantWith(testing::FieldsAre( + /*event=*/"terminated", /*body=*/testing::_)))); } diff --git a/lldb/unittests/DAP/TestBase.cpp b/lldb/unittests/DAP/TestBase.cpp index 8f9b098c8b1e1..64097d177c4a9 100644 --- a/lldb/unittests/DAP/TestBase.cpp +++ b/lldb/unittests/DAP/TestBase.cpp @@ -7,14 +7,11 @@ //===----------------------------------------------------------------------===// #include "TestBase.h" -#include "Protocol/ProtocolBase.h" #include "TestingSupport/TestUtilities.h" #include "lldb/API/SBDefines.h" #include "lldb/API/SBStructuredData.h" -#include "lldb/Host/File.h" #include "lldb/Host/MainLoop.h" #include "lldb/Host/Pipe.h" -#include "lldb/lldb-forward.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Error.h" #include "llvm/Testing/Support/Error.h" @@ -26,39 +23,17 @@ using namespace lldb; using namespace lldb_dap; using namespace lldb_dap::protocol; using namespace lldb_dap_tests; -using lldb_private::File; using lldb_private::MainLoop; -using lldb_private::MainLoopBase; -using lldb_private::NativeFile; using lldb_private::Pipe; -void TransportBase::SetUp() { - PipePairTest::SetUp(); - to_dap = std::make_unique( - "to_dap", nullptr, - std::make_shared(input.GetReadFileDescriptor(), - File::eOpenOptionReadOnly, - NativeFile::Unowned), - std::make_shared(output.GetWriteFileDescriptor(), - File::eOpenOptionWriteOnly, - NativeFile::Unowned)); - from_dap = std::make_unique( - "from_dap", nullptr, - std::make_shared(output.GetReadFileDescriptor(), - File::eOpenOptionReadOnly, - NativeFile::Unowned), - std::make_shared(input.GetWriteFileDescriptor(), - File::eOpenOptionWriteOnly, - NativeFile::Unowned)); -} - void DAPTestBase::SetUp() { TransportBase::SetUp(); dap = std::make_unique( /*log=*/nullptr, /*default_repl_mode=*/ReplMode::Auto, /*pre_init_commands=*/std::vector(), - /*transport=*/*to_dap); + /*client_name=*/"test_client", + /*transport=*/*transport, /*loop=*/loop); } void DAPTestBase::TearDown() { @@ -118,22 +93,3 @@ void DAPTestBase::LoadCore() { SBProcess process = dap->target.LoadCore(this->core->TmpName.data()); ASSERT_TRUE(process); } - -std::vector DAPTestBase::DrainOutput() { - std::vector msgs; - output.CloseWriteFileDescriptor(); - auto handle = from_dap->RegisterReadObject( - loop, [&](MainLoopBase &loop, Expected next) { - if (llvm::Error error = next.takeError()) { - loop.RequestTermination(); - consumeError(std::move(error)); - return; - } - - msgs.push_back(*next); - }); - - consumeError(handle.takeError()); - consumeError(loop.Run().takeError()); - return msgs; -} diff --git a/lldb/unittests/DAP/TestBase.h b/lldb/unittests/DAP/TestBase.h index afdfb540d39b8..4591c0fc72726 100644 --- a/lldb/unittests/DAP/TestBase.h +++ b/lldb/unittests/DAP/TestBase.h @@ -8,41 +8,81 @@ #include "DAP.h" #include "Protocol/ProtocolBase.h" -#include "TestingSupport/Host/PipeTestUtilities.h" -#include "Transport.h" #include "lldb/Host/MainLoop.h" +#include "lldb/Host/MainLoopBase.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Testing/Support/Error.h" #include "gmock/gmock.h" #include "gtest/gtest.h" namespace lldb_dap_tests { +class TestTransport final + : public lldb_private::Transport { +public: + using Message = lldb_private::Transport::Message; + + TestTransport(lldb_private::MainLoop &loop, MessageHandler &handler) + : m_loop(loop), m_handler(handler) {} + + void Event(const lldb_dap::protocol::Event &e) override { + m_loop.AddPendingCallback([this, e](lldb_private::MainLoopBase &) { + this->m_handler.OnEvent(e); + }); + } + + void Request(const lldb_dap::protocol::Request &r) override { + m_loop.AddPendingCallback([this, r](lldb_private::MainLoopBase &) { + this->m_handler.OnRequest(r); + }); + } + + void Response(const lldb_dap::protocol::Response &r) override { + m_loop.AddPendingCallback([this, r](lldb_private::MainLoopBase &) { + this->m_handler.OnResponse(r); + }); + } + + llvm::Error Run(lldb_private::MainLoop &loop, MessageHandler &) override { + return loop.Run().takeError(); + } + + void Log(llvm::StringRef message) override { + log_messages.emplace_back(message); + } + + std::vector log_messages; + +private: + lldb_private::MainLoop &m_loop; + MessageHandler &m_handler; +}; + /// A base class for tests that need transport configured for communicating DAP /// messages. -class TransportBase : public PipePairTest { +class TransportBase : public testing::Test, + public TestTransport::MessageHandler { protected: - std::unique_ptr to_dap; - std::unique_ptr from_dap; + std::vector from_dap; lldb_private::MainLoop loop; + std::unique_ptr transport; - void SetUp() override; + void SetUp() override { + transport = std::make_unique(loop, *this); + } - template - void RunOnce(const std::function)> &callback, - std::chrono::milliseconds timeout = std::chrono::seconds(1)) { - auto handle = from_dap->RegisterReadObject

( - loop, [&](lldb_private::MainLoopBase &loop, llvm::Expected

message) { - callback(std::move(message)); - loop.RequestTermination(); - }); - loop.AddCallback( - [](lldb_private::MainLoopBase &loop) { - loop.RequestTermination(); - FAIL() << "timeout waiting for read callback"; - }, - timeout); - ASSERT_THAT_EXPECTED(handle, llvm::Succeeded()); - ASSERT_THAT_ERROR(loop.Run().takeError(), llvm::Succeeded()); + void OnEvent(const lldb_dap::protocol::Event &e) override { + from_dap.emplace_back(e); + } + void OnRequest(const lldb_dap::protocol::Request &r) override { + from_dap.emplace_back(r); + } + void OnResponse(const lldb_dap::protocol::Response &r) override { + from_dap.emplace_back(r); } }; @@ -75,7 +115,12 @@ class DAPTestBase : public TransportBase { /// Closes the DAP output pipe and returns the remaining protocol messages in /// the buffer. - std::vector DrainOutput(); + // std::vector DrainOutput(); + void RunOnce() { + loop.AddPendingCallback( + [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); }); + ASSERT_THAT_ERROR(dap->Loop(), llvm::Succeeded()); + } }; } // namespace lldb_dap_tests diff --git a/lldb/unittests/Host/JSONTransportTest.cpp b/lldb/unittests/Host/JSONTransportTest.cpp index 4e94582d3bc6a..fdfde328c69b7 100644 --- a/lldb/unittests/Host/JSONTransportTest.cpp +++ b/lldb/unittests/Host/JSONTransportTest.cpp @@ -11,15 +11,18 @@ #include "lldb/Host/File.h" #include "lldb/Host/MainLoop.h" #include "lldb/Host/MainLoopBase.h" -#include "llvm/ADT/FunctionExtras.h" +#include "lldb/Utility/Log.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Error.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/JSON.h" +#include "llvm/Support/raw_ostream.h" #include "llvm/Testing/Support/Error.h" +#include "gmock/gmock.h" #include "gtest/gtest.h" #include #include -#include #include #include @@ -28,22 +31,119 @@ using namespace lldb_private; namespace { -struct JSONTestType { - std::string str; +namespace test_protocol { + +struct Req { + std::string name; }; +json::Value toJSON(const Req &T) { return json::Object{{"req", T.name}}; } +bool fromJSON(const json::Value &V, Req &T, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("req", T.name); +} +bool operator==(const Req &a, const Req &b) { return a.name == b.name; } +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const Req &V) { + OS << toJSON(V); + return OS; +} +void PrintTo(const Req &message, std::ostream *os) { + std::string O; + llvm::raw_string_ostream OS(O); + OS << message; + *os << O; +} -json::Value toJSON(const JSONTestType &T) { - return json::Object{{"str", T.str}}; +struct Resp { + std::string name; +}; +json::Value toJSON(const Resp &T) { return json::Object{{"resp", T.name}}; } +bool fromJSON(const json::Value &V, Resp &T, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("resp", T.name); +} +bool operator==(const Resp &a, const Resp &b) { return a.name == b.name; } +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const Resp &V) { + OS << toJSON(V); + return OS; +} +void PrintTo(const Resp &message, std::ostream *os) { + std::string O; + llvm::raw_string_ostream OS(O); + OS << message; + *os << O; } -bool fromJSON(const json::Value &V, JSONTestType &T, json::Path P) { +struct Evt { + std::string name; +}; +json::Value toJSON(const Evt &T) { return json::Object{{"evt", T.name}}; } +bool fromJSON(const json::Value &V, Evt &T, json::Path P) { json::ObjectMapper O(V, P); - return O && O.map("str", T.str); + return O && O.map("evt", T.name); +} +bool operator==(const Evt &a, const Evt &b) { return a.name == b.name; } +inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const Evt &V) { + OS << toJSON(V); + return OS; +} +void PrintTo(const Evt &message, std::ostream *os) { + std::string O; + llvm::raw_string_ostream OS(O); + OS << message; + *os << O; +} + +using Message = std::variant; +json::Value toJSON(const Message &T) { + if (const Req *req = std::get_if(&T)) + return toJSON(*req); + if (const Resp *resp = std::get_if(&T)) + return toJSON(*resp); + if (const Evt *evt = std::get_if(&T)) + return toJSON(*evt); + llvm_unreachable("unknown message type"); +} +bool fromJSON(const json::Value &V, Message &T, json::Path P) { + const json::Object *O = V.getAsObject(); + if (!O) { + P.report("expected object"); + return false; + } + if (O->get("req")) { + Req R; + if (!fromJSON(V, R, P)) + return false; + + T = std::move(R); + return true; + } + if (O->get("resp")) { + Resp R; + if (!fromJSON(V, R, P)) + return false; + + T = std::move(R); + return true; + } + if (O->get("evt")) { + Evt E; + if (!fromJSON(V, E, P)) + return false; + + T = std::move(E); + return true; + } + P.report("unknown message type"); + return false; } -template class JSONTransportTest : public PipePairTest { +} // namespace test_protocol + +template +class JSONTransportTest : public PipePairTest { + protected: - std::unique_ptr transport; + std::unique_ptr transport; MainLoop loop; void SetUp() override { @@ -57,53 +157,59 @@ template class JSONTransportTest : public PipePairTest { NativeFile::Unowned)); } - template - Expected

- RunOnce(std::chrono::milliseconds timeout = std::chrono::seconds(1)) { - std::promise> promised_message; - std::future> future_message = promised_message.get_future(); - RunUntil

( - [&promised_message](Expected

message) mutable -> bool { - promised_message.set_value(std::move(message)); - return /*keep_going*/ false; - }, - timeout); - return future_message.get(); + class MessageCollector final + : public Transport::MessageHandler { + public: + std::vector messages; + void OnEvent(const Evt &V) override { messages.emplace_back(V); } + void OnRequest(const Req &V) override { messages.emplace_back(V); } + void OnResponse(const Resp &V) override { messages.emplace_back(V); } + }; + + /// Run the transport MainLoop and return any messages received. + Expected> + Run(std::chrono::milliseconds timeout = std::chrono::milliseconds(5000)) { + MessageCollector collector; + loop.AddCallback([](MainLoopBase &loop) { loop.RequestTermination(); }, + timeout); + if (auto error = transport->Run(loop, collector)) + return error; + return std::move(collector.messages); } - /// RunUntil runs the event loop until the callback returns `false` or a - /// timeout has occurred. - template - void RunUntil(std::function)> callback, - std::chrono::milliseconds timeout = std::chrono::seconds(1)) { - auto handle = transport->RegisterReadObject

( - loop, [&callback](MainLoopBase &loop, Expected

message) mutable { - bool keep_going = callback(std::move(message)); - if (!keep_going) - loop.RequestTermination(); - }); - loop.AddCallback( - [&callback](MainLoopBase &loop) mutable { - loop.RequestTermination(); - callback(createStringError("timeout")); - }, - timeout); - EXPECT_THAT_EXPECTED(handle, Succeeded()); - EXPECT_THAT_ERROR(loop.Run().takeError(), Succeeded()); - } - - template llvm::Expected Write(Ts... args) { + template void Write(Ts... args) { std::string message; for (const auto &arg : {args...}) message += Encode(arg); - return input.Write(message.data(), message.size()); + EXPECT_THAT_EXPECTED(input.Write(message.data(), message.size()), + Succeeded()); + } + + template void WriteAndCloseInput(Ts... args) { + Write(std::forward(args)...); + input.CloseWriteFileDescriptor(); } virtual std::string Encode(const json::Value &) = 0; }; +class TestHTTPDelimitedJSONTransport final + : public HTTPDelimitedJSONTransport { +public: + using HTTPDelimitedJSONTransport::HTTPDelimitedJSONTransport; + + void Log(llvm::StringRef message) override { + log_messages.emplace_back(message); + } + + std::vector log_messages; +}; + class HTTPDelimitedJSONTransportTest - : public JSONTransportTest { + : public JSONTransportTest { public: using JSONTransportTest::JSONTransportTest; @@ -118,7 +224,22 @@ class HTTPDelimitedJSONTransportTest } }; -class JSONRPCTransportTest : public JSONTransportTest { +class TestJSONRPCTransport final + : public JSONRPCTransport { +public: + using JSONRPCTransport::JSONRPCTransport; + + void Log(llvm::StringRef message) override { + log_messages.emplace_back(message); + } + + std::vector log_messages; +}; + +class JSONRPCTransportTest + : public JSONTransportTest { public: using JSONTransportTest::JSONTransportTest; @@ -134,6 +255,7 @@ class JSONRPCTransportTest : public JSONTransportTest { // Failing on Windows, see https://github.com/llvm/llvm-project/issues/153446. #ifndef _WIN32 +using namespace test_protocol; TEST_F(HTTPDelimitedJSONTransportTest, MalformedRequests) { std::string malformed_header = @@ -141,84 +263,65 @@ TEST_F(HTTPDelimitedJSONTransportTest, MalformedRequests) { ASSERT_THAT_EXPECTED( input.Write(malformed_header.data(), malformed_header.size()), Succeeded()); - ASSERT_THAT_EXPECTED(RunOnce(), - FailedWithMessage("invalid content length: -1")); + ASSERT_THAT_EXPECTED(Run(), FailedWithMessage("invalid content length: -1")); } TEST_F(HTTPDelimitedJSONTransportTest, Read) { - ASSERT_THAT_EXPECTED(Write(JSONTestType{"foo"}), Succeeded()); - ASSERT_THAT_EXPECTED(RunOnce(), - HasValue(testing::FieldsAre(/*str=*/"foo"))); + WriteAndCloseInput(Req{"foo"}); + ASSERT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre(Req{"foo"}))); } TEST_F(HTTPDelimitedJSONTransportTest, ReadMultipleMessagesInSingleWrite) { - ASSERT_THAT_EXPECTED(Write(JSONTestType{"one"}, JSONTestType{"two"}), - Succeeded()); - unsigned count = 0; - RunUntil([&](Expected message) -> bool { - if (count == 0) { - EXPECT_THAT_EXPECTED(message, - HasValue(testing::FieldsAre(/*str=*/"one"))); - } else if (count == 1) { - EXPECT_THAT_EXPECTED(message, - HasValue(testing::FieldsAre(/*str=*/"two"))); - } - - count++; - return count < 2; - }); + WriteAndCloseInput(Message{Req{"one"}}, Message{Resp{"two"}}, + Message{Evt{"three"}}); + EXPECT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre( + Req{"one"}, Resp{"two"}, Evt{"three"}))); } TEST_F(HTTPDelimitedJSONTransportTest, ReadAcrossMultipleChunks) { std::string long_str = std::string(2048, 'x'); - ASSERT_THAT_EXPECTED(Write(JSONTestType{long_str}), Succeeded()); - ASSERT_THAT_EXPECTED(RunOnce(), - HasValue(testing::FieldsAre(/*str=*/long_str))); + WriteAndCloseInput(Req{long_str}); + ASSERT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre(Req{long_str}))); } TEST_F(HTTPDelimitedJSONTransportTest, ReadPartialMessage) { - std::string message = Encode(JSONTestType{"foo"}); - std::string part1 = message.substr(0, 28); - std::string part2 = message.substr(28); + std::string message = Encode(Req{"foo"}); + auto split_at = message.size() / 2; + std::string part1 = message.substr(0, split_at); + std::string part2 = message.substr(split_at); ASSERT_THAT_EXPECTED(input.Write(part1.data(), part1.size()), Succeeded()); - - ASSERT_THAT_EXPECTED( - RunOnce(/*timeout=*/std::chrono::milliseconds(10)), - FailedWithMessage("timeout")); + ASSERT_THAT_EXPECTED(Run(/*timeout=*/std::chrono::milliseconds(10)), + HasValue(testing::IsEmpty())); ASSERT_THAT_EXPECTED(input.Write(part2.data(), part2.size()), Succeeded()); - - ASSERT_THAT_EXPECTED(RunOnce(), - HasValue(testing::FieldsAre(/*str=*/"foo"))); + input.CloseWriteFileDescriptor(); + ASSERT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre(Req{"foo"}))); } TEST_F(HTTPDelimitedJSONTransportTest, ReadWithZeroByteWrites) { - std::string message = Encode(JSONTestType{"foo"}); - std::string part1 = message.substr(0, 28); - std::string part2 = message.substr(28); + std::string message = Encode(Req{"foo"}); + auto split_at = message.size() / 2; + std::string part1 = message.substr(0, split_at); + std::string part2 = message.substr(split_at); ASSERT_THAT_EXPECTED(input.Write(part1.data(), part1.size()), Succeeded()); - ASSERT_THAT_EXPECTED( - RunOnce(/*timeout=*/std::chrono::milliseconds(10)), - FailedWithMessage("timeout")); + ASSERT_THAT_EXPECTED(Run(/*timeout=*/std::chrono::milliseconds(10)), + HasValue(testing::IsEmpty())); ASSERT_THAT_EXPECTED(input.Write(part1.data(), 0), Succeeded()); // zero-byte write. - - ASSERT_THAT_EXPECTED( - RunOnce(/*timeout=*/std::chrono::milliseconds(10)), - FailedWithMessage("timeout")); + ASSERT_THAT_EXPECTED(Run(/*timeout=*/std::chrono::milliseconds(10)), + HasValue(testing::IsEmpty())); ASSERT_THAT_EXPECTED(input.Write(part2.data(), part2.size()), Succeeded()); - - ASSERT_THAT_EXPECTED(RunOnce(), - HasValue(testing::FieldsAre(/*str=*/"foo"))); + input.CloseWriteFileDescriptor(); + ASSERT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre(Req{"foo"}))); } TEST_F(HTTPDelimitedJSONTransportTest, ReadWithEOF) { input.CloseWriteFileDescriptor(); - ASSERT_THAT_EXPECTED(RunOnce(), Failed()); + ASSERT_THAT_EXPECTED(Run(), HasValue(testing::IsEmpty())); } TEST_F(HTTPDelimitedJSONTransportTest, ReaderWithUnhandledData) { @@ -231,32 +334,35 @@ TEST_F(HTTPDelimitedJSONTransportTest, ReaderWithUnhandledData) { ASSERT_THAT_EXPECTED(input.Write(message.data(), message.size() - 1), Succeeded()); input.CloseWriteFileDescriptor(); - ASSERT_THAT_EXPECTED(RunOnce(), - Failed()); + ASSERT_THAT_EXPECTED(Run(), Failed()); } TEST_F(HTTPDelimitedJSONTransportTest, NoDataTimeout) { - ASSERT_THAT_EXPECTED( - RunOnce(/*timeout=*/std::chrono::milliseconds(10)), - FailedWithMessage("timeout")); + ASSERT_THAT_EXPECTED(Run(/*timeout=*/std::chrono::milliseconds(10)), + HasValue(testing::IsEmpty())); } TEST_F(HTTPDelimitedJSONTransportTest, InvalidTransport) { - transport = std::make_unique(nullptr, nullptr); - auto handle = transport->RegisterReadObject( - loop, [&](MainLoopBase &, llvm::Expected) {}); - ASSERT_THAT_EXPECTED(handle, FailedWithMessage("IO object is not valid.")); + transport = + std::make_unique(nullptr, nullptr); + ASSERT_THAT_EXPECTED(Run(), FailedWithMessage("IO object is not valid.")); } TEST_F(HTTPDelimitedJSONTransportTest, Write) { - ASSERT_THAT_ERROR(transport->Write(JSONTestType{"foo"}), Succeeded()); + transport->Request(Req{"foo"}); + transport->Response(Resp{"bar"}); + transport->Event(Evt{"baz"}); output.CloseWriteFileDescriptor(); char buf[1024]; Expected bytes_read = output.Read(buf, sizeof(buf), std::chrono::milliseconds(1)); ASSERT_THAT_EXPECTED(bytes_read, Succeeded()); ASSERT_EQ(StringRef(buf, *bytes_read), StringRef("Content-Length: 13\r\n\r\n" - R"json({"str":"foo"})json")); + R"({"req":"foo"})" + "Content-Length: 14\r\n\r\n" + R"({"resp":"bar"})" + "Content-Length: 13\r\n\r\n" + R"({"evt":"baz"})")); } TEST_F(JSONRPCTransportTest, MalformedRequests) { @@ -264,80 +370,80 @@ TEST_F(JSONRPCTransportTest, MalformedRequests) { ASSERT_THAT_EXPECTED( input.Write(malformed_header.data(), malformed_header.size()), Succeeded()); - ASSERT_THAT_EXPECTED(RunOnce(), llvm::Failed()); + ASSERT_THAT_EXPECTED( + Run(), FailedWithMessage("[1:2, byte=2]: Invalid JSON value (null?)")); } TEST_F(JSONRPCTransportTest, Read) { - ASSERT_THAT_EXPECTED(Write(JSONTestType{"foo"}), Succeeded()); - ASSERT_THAT_EXPECTED(RunOnce(), - HasValue(testing::FieldsAre(/*str=*/"foo"))); + WriteAndCloseInput(Message{Req{"foo"}}, Message{Resp{"bar"}}, + Message{Evt{"baz"}}); + ASSERT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre( + Req{"foo"}, Resp{"bar"}, Evt{"baz"}))); } TEST_F(JSONRPCTransportTest, ReadAcrossMultipleChunks) { - std::string long_str = std::string(2048, 'x'); - std::string message = Encode(JSONTestType{long_str}); - ASSERT_THAT_EXPECTED(input.Write(message.data(), message.size()), - Succeeded()); - ASSERT_THAT_EXPECTED(RunOnce(), - HasValue(testing::FieldsAre(/*str=*/long_str))); + // Use a string longer than the chunk size to ensure we split the message + // across the chunk boundary. + std::string long_str = + std::string(JSONTransport::kReadBufferSize + 10, 'x'); + WriteAndCloseInput(Req{long_str}); + ASSERT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre(Req{long_str}))); } TEST_F(JSONRPCTransportTest, ReadPartialMessage) { - std::string message = R"({"str": "foo"})" + std::string message = R"({"req": "foo"})" "\n"; std::string part1 = message.substr(0, 7); std::string part2 = message.substr(7); ASSERT_THAT_EXPECTED(input.Write(part1.data(), part1.size()), Succeeded()); - - ASSERT_THAT_EXPECTED( - RunOnce(/*timeout=*/std::chrono::milliseconds(10)), - FailedWithMessage("timeout")); + ASSERT_THAT_EXPECTED(Run(std::chrono::milliseconds(10)), + HasValue(testing::IsEmpty())); ASSERT_THAT_EXPECTED(input.Write(part2.data(), part2.size()), Succeeded()); - - ASSERT_THAT_EXPECTED(RunOnce(), - HasValue(testing::FieldsAre(/*str=*/"foo"))); + input.CloseWriteFileDescriptor(); + ASSERT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre(Req{"foo"}))); } TEST_F(JSONRPCTransportTest, ReadWithEOF) { input.CloseWriteFileDescriptor(); - ASSERT_THAT_EXPECTED(RunOnce(), Failed()); + ASSERT_THAT_EXPECTED(Run(), HasValue(testing::IsEmpty())); } TEST_F(JSONRPCTransportTest, ReaderWithUnhandledData) { - std::string message = R"json({"str": "foo"})json" - "\n"; + std::string message = R"json({"req": "foo")json"; // Write an incomplete message and close the handle. - ASSERT_THAT_EXPECTED(input.Write(message.data(), message.size() - 1), + ASSERT_THAT_EXPECTED(input.Write(message.data(), message.size()), Succeeded()); input.CloseWriteFileDescriptor(); - ASSERT_THAT_EXPECTED(RunOnce(), - Failed()); + EXPECT_THAT_EXPECTED(Run(), Failed()); } TEST_F(JSONRPCTransportTest, Write) { - ASSERT_THAT_ERROR(transport->Write(JSONTestType{"foo"}), Succeeded()); + transport->Request(Req{"foo"}); + transport->Response(Resp{"bar"}); + transport->Event(Evt{"baz"}); output.CloseWriteFileDescriptor(); char buf[1024]; Expected bytes_read = output.Read(buf, sizeof(buf), std::chrono::milliseconds(1)); ASSERT_THAT_EXPECTED(bytes_read, Succeeded()); - ASSERT_EQ(StringRef(buf, *bytes_read), StringRef(R"json({"str":"foo"})json" + ASSERT_EQ(StringRef(buf, *bytes_read), StringRef(R"({"req":"foo"})" + "\n" + R"({"resp":"bar"})" + "\n" + R"({"evt":"baz"})" "\n")); } TEST_F(JSONRPCTransportTest, InvalidTransport) { - transport = std::make_unique(nullptr, nullptr); - auto handle = transport->RegisterReadObject( - loop, [&](MainLoopBase &, llvm::Expected) {}); - ASSERT_THAT_EXPECTED(handle, FailedWithMessage("IO object is not valid.")); + transport = std::make_unique(nullptr, nullptr); + ASSERT_THAT_EXPECTED(Run(), FailedWithMessage("IO object is not valid.")); } TEST_F(JSONRPCTransportTest, NoDataTimeout) { - ASSERT_THAT_EXPECTED( - RunOnce(/*timeout=*/std::chrono::milliseconds(10)), - FailedWithMessage("timeout")); + ASSERT_THAT_EXPECTED(Run(/*timeout=*/std::chrono::milliseconds(10)), + HasValue(testing::ElementsAre())); } #endif diff --git a/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp b/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp index 2ac40c41dd28e..588093edf321a 100644 --- a/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp +++ b/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp @@ -21,7 +21,9 @@ #include "lldb/Protocol/MCP/MCPError.h" #include "lldb/Protocol/MCP/Protocol.h" #include "llvm/Support/Error.h" +#include "llvm/Support/JSON.h" #include "llvm/Testing/Support/Error.h" +#include "gmock/gmock.h" #include "gtest/gtest.h" #include #include @@ -43,11 +45,18 @@ class TestProtocolServerMCP : public lldb_private::mcp::ProtocolServerMCP { using ProtocolServerMCP::ProtocolServerMCP; }; -class TestJSONTransport : public lldb_private::JSONRPCTransport { +using Message = typename Transport::Message; + +class TestJSONTransport final + : public lldb_private::JSONRPCTransport { public: using JSONRPCTransport::JSONRPCTransport; - using JSONRPCTransport::Parse; - using JSONRPCTransport::WriteImpl; + + void Log(llvm::StringRef message) override { + log_messages.emplace_back(message); + } + + std::vector log_messages; }; /// Test tool that returns it argument as text. @@ -55,7 +64,7 @@ class TestTool : public Tool { public: using Tool::Tool; - virtual llvm::Expected Call(const ToolArguments &args) override { + llvm::Expected Call(const ToolArguments &args) override { std::string argument; if (const json::Object *args_obj = std::get(args).getAsObject()) { @@ -73,7 +82,7 @@ class TestTool : public Tool { class TestResourceProvider : public ResourceProvider { using ResourceProvider::ResourceProvider; - virtual std::vector GetResources() const override { + std::vector GetResources() const override { std::vector resources; Resource resource; @@ -86,7 +95,7 @@ class TestResourceProvider : public ResourceProvider { return resources; } - virtual llvm::Expected + llvm::Expected ReadResource(llvm::StringRef uri) const override { if (uri != "lldb://foo/bar") return llvm::make_error(uri.str()); @@ -107,7 +116,7 @@ class ErrorTool : public Tool { public: using Tool::Tool; - virtual llvm::Expected Call(const ToolArguments &args) override { + llvm::Expected Call(const ToolArguments &args) override { return llvm::createStringError("error"); } }; @@ -117,7 +126,7 @@ class FailTool : public Tool { public: using Tool::Tool; - virtual llvm::Expected Call(const ToolArguments &args) override { + llvm::Expected Call(const ToolArguments &args) override { TextResult text_result; text_result.content.emplace_back(TextContent{{"failed"}}); text_result.isError = true; @@ -138,26 +147,33 @@ class ProtocolServerMCPTest : public ::testing::Test { static constexpr llvm::StringLiteral k_localhost = "localhost"; llvm::Error Write(llvm::StringRef message) { - return m_transport_up->WriteImpl(llvm::formatv("{0}\n", message).str()); + std::string output = llvm::formatv("{0}\n", message).str(); + size_t bytes_written = output.size(); + return m_io_sp->Write(output.data(), bytes_written).takeError(); } - template - void - RunOnce(const std::function)> &callback, - std::chrono::milliseconds timeout = std::chrono::milliseconds(100)) { - auto handle = m_transport_up->RegisterReadObject

( - loop, [&](lldb_private::MainLoopBase &loop, llvm::Expected

message) { - callback(std::move(message)); - loop.RequestTermination(); - }); - loop.AddCallback( - [&](lldb_private::MainLoopBase &loop) { - loop.RequestTermination(); - FAIL() << "timeout waiting for read callback"; - }, - timeout); - ASSERT_THAT_EXPECTED(handle, llvm::Succeeded()); - ASSERT_THAT_ERROR(loop.Run().takeError(), llvm::Succeeded()); + void CloseInput() { + EXPECT_THAT_ERROR(m_io_sp->Close().takeError(), Succeeded()); + } + + class MessageCollector final + : public Transport::MessageHandler { + public: + std::vector messages; + void OnEvent(const Notification &V) override { messages.emplace_back(V); } + void OnRequest(const Request &V) override { messages.emplace_back(V); } + void OnResponse(const Response &V) override { messages.emplace_back(V); } + }; + + /// Run the transport MainLoop and return any messages received. + Expected> + Run(std::chrono::milliseconds timeout = std::chrono::milliseconds(100)) { + MessageCollector collector; + loop.AddCallback([](MainLoopBase &loop) { loop.RequestTermination(); }, + timeout); + if (auto error = m_transport_up->Run(loop, collector)) + return error; + return std::move(collector.messages); } void SetUp() override { @@ -206,37 +222,39 @@ TEST_F(ProtocolServerMCPTest, Initialization) { llvm::StringLiteral response = R"json( {"id":0,"jsonrpc":"2.0","result":{"capabilities":{"resources":{"listChanged":false,"subscribe":false},"tools":{"listChanged":true}},"protocolVersion":"2024-11-05","serverInfo":{"name":"lldb-mcp","version":"0.1.0"}}})json"; - ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - RunOnce([&](llvm::Expected response_str) { - ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); - llvm::Expected response_json = json::parse(*response_str); - ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); - - llvm::Expected expected_json = json::parse(response); - ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); - - EXPECT_EQ(*response_json, *expected_json); - }); + ASSERT_THAT_ERROR(Write(request), Succeeded()); + llvm::Expected expected_json = json::parse(response); + ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); + EXPECT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre(*expected_json))); } TEST_F(ProtocolServerMCPTest, ToolsList) { llvm::StringLiteral request = R"json({"method":"tools/list","params":{},"jsonrpc":"2.0","id":1})json"; - llvm::StringLiteral response = - R"json({"id":1,"jsonrpc":"2.0","result":{"tools":[{"description":"test tool","inputSchema":{"type":"object"},"name":"test"},{"description":"Run an lldb command.","inputSchema":{"properties":{"arguments":{"type":"string"},"debugger_id":{"type":"number"}},"required":["debugger_id"],"type":"object"},"name":"lldb_command"}]}})json"; - - ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - RunOnce([&](llvm::Expected response_str) { - ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); - - llvm::Expected response_json = json::parse(*response_str); - ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); - llvm::Expected expected_json = json::parse(response); - ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); + ToolDefinition test_tool; + test_tool.name = "test"; + test_tool.description = "test tool"; + test_tool.inputSchema = json::Object{{"type", "object"}}; + + ToolDefinition lldb_command_tool; + lldb_command_tool.description = "Run an lldb command."; + lldb_command_tool.name = "lldb_command"; + lldb_command_tool.inputSchema = json::Object{ + {"type", "object"}, + {"properties", + json::Object{{"arguments", json::Object{{"type", "string"}}}, + {"debugger_id", json::Object{{"type", "number"}}}}}, + {"required", json::Array{"debugger_id"}}}; + Response response; + response.id = 1; + response.result = json::Object{ + {"tools", + json::Array{std::move(test_tool), std::move(lldb_command_tool)}}, + }; - EXPECT_EQ(*response_json, *expected_json); - }); + ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); + EXPECT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre(response))); } TEST_F(ProtocolServerMCPTest, ResourcesList) { @@ -246,17 +264,9 @@ TEST_F(ProtocolServerMCPTest, ResourcesList) { R"json({"id":2,"jsonrpc":"2.0","result":{"resources":[{"description":"description","mimeType":"application/json","name":"name","uri":"lldb://foo/bar"}]}})json"; ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - RunOnce([&](llvm::Expected response_str) { - ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); - - llvm::Expected response_json = json::parse(*response_str); - ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); - - llvm::Expected expected_json = json::parse(response); - ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); - - EXPECT_EQ(*response_json, *expected_json); - }); + llvm::Expected expected_json = json::parse(response); + ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); + EXPECT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre(*expected_json))); } TEST_F(ProtocolServerMCPTest, ToolsCall) { @@ -266,17 +276,9 @@ TEST_F(ProtocolServerMCPTest, ToolsCall) { R"json({"id":11,"jsonrpc":"2.0","result":{"content":[{"text":"foo","type":"text"}],"isError":false}})json"; ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - RunOnce([&](llvm::Expected response_str) { - ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); - - llvm::Expected response_json = json::parse(*response_str); - ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); - - llvm::Expected expected_json = json::parse(response); - ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); - - EXPECT_EQ(*response_json, *expected_json); - }); + llvm::Expected expected_json = json::parse(response); + ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); + ASSERT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre(*expected_json))); } TEST_F(ProtocolServerMCPTest, ToolsCallError) { @@ -288,17 +290,9 @@ TEST_F(ProtocolServerMCPTest, ToolsCallError) { R"json({"error":{"code":-32603,"message":"error"},"id":11,"jsonrpc":"2.0"})json"; ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - RunOnce([&](llvm::Expected response_str) { - ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); - - llvm::Expected response_json = json::parse(*response_str); - ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); - - llvm::Expected expected_json = json::parse(response); - ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); - - EXPECT_EQ(*response_json, *expected_json); - }); + llvm::Expected expected_json = json::parse(response); + ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); + ASSERT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre(*expected_json))); } TEST_F(ProtocolServerMCPTest, ToolsCallFail) { @@ -310,17 +304,9 @@ TEST_F(ProtocolServerMCPTest, ToolsCallFail) { R"json({"id":11,"jsonrpc":"2.0","result":{"content":[{"text":"failed","type":"text"}],"isError":true}})json"; ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - RunOnce([&](llvm::Expected response_str) { - ASSERT_THAT_EXPECTED(response_str, llvm::Succeeded()); - - llvm::Expected response_json = json::parse(*response_str); - ASSERT_THAT_EXPECTED(response_json, llvm::Succeeded()); - - llvm::Expected expected_json = json::parse(response); - ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); - - EXPECT_EQ(*response_json, *expected_json); - }); + llvm::Expected expected_json = json::parse(response); + ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); + ASSERT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre(*expected_json))); } TEST_F(ProtocolServerMCPTest, NotificationInitialized) { From e4abbc0d59972853f1eb65cc80b8fac9b3bb4b64 Mon Sep 17 00:00:00 2001 From: John Harrison Date: Wed, 13 Aug 2025 16:04:23 -0700 Subject: [PATCH 2/5] Addressing reviewer comments. --- lldb/include/lldb/Host/JSONTransport.h | 27 ++++++++++------ lldb/tools/lldb-dap/DAP.cpp | 44 ++++++++++++++++---------- lldb/tools/lldb-dap/DAP.h | 2 +- lldb/tools/lldb-dap/tool/lldb-dap.cpp | 5 +-- 4 files changed, 48 insertions(+), 30 deletions(-) diff --git a/lldb/include/lldb/Host/JSONTransport.h b/lldb/include/lldb/Host/JSONTransport.h index 18126f599c380..f160599243de7 100644 --- a/lldb/include/lldb/Host/JSONTransport.h +++ b/lldb/include/lldb/Host/JSONTransport.h @@ -139,9 +139,8 @@ class JSONTransport : public Transport { std::string output = Encode(message); size_t bytes_written = output.size(); Status status = m_out->Write(output.data(), bytes_written); - if (status.Fail()) { - this->Logv("writing failed: s{0}", status.AsCString()); - } + if (status.Fail()) + this->Logv("writing failed: {0}", status.AsCString()); } llvm::SmallString m_buffer; @@ -170,8 +169,8 @@ class JSONTransport : public Transport { return; } - for (const auto &raw_message : *raw_messages) { - auto message = + for (const std::string &raw_message : *raw_messages) { + llvm::Expected::Message> message = llvm::json::parse::Message>( raw_message); if (!message) { @@ -182,13 +181,20 @@ class JSONTransport : public Transport { if (Evt *evt = std::get_if(&*message)) { handler.OnEvent(*evt); - } else if (Req *req = std::get_if(&*message)) { + continue; + } + + if (Req *req = std::get_if(&*message)) { handler.OnRequest(*req); - } else if (Resp *resp = std::get_if(&*message)) { + continue; + } + + if (Resp *resp = std::get_if(&*message)) { handler.OnResponse(*resp); - } else { - llvm_unreachable("unknown message type"); + continue; } + + llvm_unreachable("unknown message type"); } } @@ -235,7 +241,8 @@ class HTTPDelimitedJSONTransport : public JSONTransport { auto [headers, rest] = buffer.split(kEndOfHeader); size_t content_length = 0; // HTTP Headers are formatted like ` ':' []`. - for (const auto &header : llvm::split(headers, kHeaderSeparator)) { + for (const llvm::StringRef &header : + llvm::split(headers, kHeaderSeparator)) { auto [key, value] = header.split(kHeaderFieldSeparator); // 'Content-Length' is the only meaningful key at the moment. Others are // ignored. diff --git a/lldb/tools/lldb-dap/DAP.cpp b/lldb/tools/lldb-dap/DAP.cpp index a9a0fe75a35b7..2eced4f78fbd3 100644 --- a/lldb/tools/lldb-dap/DAP.cpp +++ b/lldb/tools/lldb-dap/DAP.cpp @@ -268,24 +268,33 @@ void DAP::SendJSON(const llvm::json::Value &json) { void DAP::Send(const Message &message) { if (const protocol::Event *event = std::get_if(&message)) { transport.Event(*event); - } else if (const Request *req = std::get_if(&message)) { + return; + } + + if (const Request *req = std::get_if(&message)) { transport.Request(*req); - } else if (const Response *resp = std::get_if(&message)) { + return; + } + + if (const Response *resp = std::get_if(&message)) { // FIXME: After all the requests have migrated from LegacyRequestHandler > // RequestHandler<> this should be handled in RequestHandler<>::operator(). - if (debugger.InterruptRequested()) - // If the debugger was interrupted, convert this response into a - // 'cancelled' response because we might have a partial result. + + // If the debugger was interrupted, convert this response into a + // 'cancelled' response because we might have a partial result. + if (debugger.InterruptRequested()) { transport.Response(Response{/*request_seq=*/resp->request_seq, /*command=*/resp->command, /*success=*/false, /*message=*/eResponseMessageCancelled, /*body=*/std::nullopt}); - else + } else { transport.Response(*resp); - } else { - llvm_unreachable("Unexpected message type"); + } + return; } + + llvm_unreachable("Unexpected message type"); } // "OutputEvent": { @@ -916,7 +925,8 @@ llvm::Error DAP::Disconnect(bool terminateDebuggee) { SendTerminatedEvent(); - disconnecting = true; + std::lock_guard guard(m_queue_mutex); + m_disconnecting = true; return ToError(error); } @@ -952,7 +962,7 @@ void DAP::OnEvent(const protocol::Event &event) { void DAP::OnRequest(const protocol::Request &request) { if (request.command == "disconnect") - disconnecting = true; + m_disconnecting = true; const std::optional cancel_args = getArgumentsIfRequest(request, "cancel"); @@ -990,12 +1000,12 @@ void DAP::OnResponse(const protocol::Response &response) { void DAP::TransportHandler(llvm::Error *error) { llvm::ErrorAsOutParameter ErrAsOutParam(*error); - auto cleanup = llvm::make_scope_exit([&]() { - // Ensure we're marked as disconnecting when the reader exits. - disconnecting = true; - m_queue_cv.notify_all(); - }); *error = transport.Run(m_loop, *this); + + std::lock_guard guard(m_queue_mutex); + // Ensure we're marked as disconnecting when the reader exits. + m_disconnecting = true; + m_queue_cv.notify_all(); } llvm::Error DAP::Loop() { @@ -1010,9 +1020,9 @@ llvm::Error DAP::Loop() { while (true) { std::unique_lock lock(m_queue_mutex); - m_queue_cv.wait(lock, [&] { return disconnecting || !m_queue.empty(); }); + m_queue_cv.wait(lock, [&] { return m_disconnecting || !m_queue.empty(); }); - if (disconnecting && m_queue.empty()) + if (m_disconnecting && m_queue.empty()) break; Message next = m_queue.front(); diff --git a/lldb/tools/lldb-dap/DAP.h b/lldb/tools/lldb-dap/DAP.h index 628f97257d5f0..71ec6a7faf7bc 100644 --- a/lldb/tools/lldb-dap/DAP.h +++ b/lldb/tools/lldb-dap/DAP.h @@ -118,7 +118,6 @@ struct DAP final : private DAPTransport::MessageHandler { /// The focused thread for this DAP session. lldb::tid_t focus_tid = LLDB_INVALID_THREAD_ID; - bool disconnecting = false; llvm::once_flag terminated_event_flag; bool stop_at_entry = false; bool is_attach = false; @@ -467,6 +466,7 @@ struct DAP final : private DAPTransport::MessageHandler { std::deque m_queue; std::mutex m_queue_mutex; std::condition_variable m_queue_cv; + bool m_disconnecting = false; // Loop for managing reading from the client. lldb_private::MainLoop &m_loop; diff --git a/lldb/tools/lldb-dap/tool/lldb-dap.cpp b/lldb/tools/lldb-dap/tool/lldb-dap.cpp index c728b0af94c7c..b74085f25f4e2 100644 --- a/lldb/tools/lldb-dap/tool/lldb-dap.cpp +++ b/lldb/tools/lldb-dap/tool/lldb-dap.cpp @@ -39,6 +39,7 @@ #include "llvm/Support/PrettyStackTrace.h" #include "llvm/Support/Signals.h" #include "llvm/Support/Threading.h" +#include "llvm/Support/WithColor.h" #include "llvm/Support/raw_ostream.h" #include #include @@ -349,8 +350,8 @@ serveConnection(const Socket::SocketProtocol &protocol, const std::string &name, for (auto [loop, dap] : dap_sessions) { if (llvm::Error error = dap->Disconnect()) { client_failed = true; - llvm::errs() << "DAP client disconnected failed: " - << llvm::toString(std::move(error)) << "\n"; + llvm::WithColor::error() << "DAP client disconnected failed: " + << llvm::toString(std::move(error)) << "\n"; } loop->AddPendingCallback( [](MainLoopBase &loop) { loop.RequestTermination(); }); From f268900c4f9f9d29d8d94380189fad1cb37e3a08 Mon Sep 17 00:00:00 2001 From: John Harrison Date: Fri, 15 Aug 2025 17:59:35 -0700 Subject: [PATCH 3/5] Refactoring the Transport to not directly run the MainLoop and surface errors to the MessageHandler instead of internally logging and ignoring them. This allows the MessageHandler to determine what to do when an error occurs. --- lldb/include/lldb/Host/JSONTransport.h | 99 +++++++++---------- lldb/tools/lldb-dap/DAP.cpp | 99 +++++++++++++++---- lldb/tools/lldb-dap/DAP.h | 6 +- lldb/tools/lldb-dap/Transport.h | 4 - lldb/unittests/DAP/DAPTest.cpp | 2 +- lldb/unittests/DAP/Handler/DisconnectTest.cpp | 5 +- lldb/unittests/DAP/TestBase.cpp | 41 +++++++- lldb/unittests/DAP/TestBase.h | 36 +++++-- lldb/unittests/Host/JSONTransportTest.cpp | 69 +++++++++---- .../ProtocolServer/ProtocolMCPServerTest.cpp | 15 ++- 10 files changed, 265 insertions(+), 111 deletions(-) diff --git a/lldb/include/lldb/Host/JSONTransport.h b/lldb/include/lldb/Host/JSONTransport.h index f160599243de7..adfcf662e2c75 100644 --- a/lldb/include/lldb/Host/JSONTransport.h +++ b/lldb/include/lldb/Host/JSONTransport.h @@ -62,25 +62,43 @@ template class Transport { virtual ~Transport() = default; - // Called by transport to send outgoing messages. - virtual void Event(const Evt &) = 0; - virtual void Request(const Req &) = 0; - virtual void Response(const Resp &) = 0; + /// Sends an event, a message that does not require a response. + virtual llvm::Error Event(const Evt &) = 0; + /// Sends a request, a message that expects a response. + virtual llvm::Error Request(const Req &) = 0; + /// Sends a response to a specific request. + virtual llvm::Error Response(const Resp &) = 0; /// Implemented to handle incoming messages. (See Run() below). class MessageHandler { public: virtual ~MessageHandler() = default; + /// Called when an event is received. virtual void OnEvent(const Evt &) = 0; + /// Called when a request is received. virtual void OnRequest(const Req &) = 0; + /// Called when a response is received. virtual void OnResponse(const Resp &) = 0; + + /// Called when an error occurs while reading from the transport. + /// + /// NOTE: This does *NOT* indicate that a specific request failed, but that + /// there was an error in the underlying transport. + virtual void OnError(MainLoopBase &, llvm::Error) = 0; + + /// Called on EOF or disconnect. + virtual void OnEOF() = 0; }; - /// Called by server or client to receive messages from the connection. - /// The transport should in turn invoke the handler to process messages. - /// The MainLoop is used to handle reading from the incoming connection and - /// will run until the loop is terminated. - virtual llvm::Error Run(MainLoop &, MessageHandler &) = 0; + using MessageHandlerSP = std::shared_ptr; + + /// RegisterMessageHandler registers the Transport with the given MainLoop and + /// handles any incoming messages using the given MessageHandler. + /// + /// If an unexpected error occurs, the MainLoop will be terminated and a log + /// message will include additional information about the termination reason. + virtual llvm::Expected + RegisterMessageHandler(MainLoop &loop, MessageHandler &handler) = 0; protected: template inline auto Logv(const char *Fmt, Ts &&...Vals) { @@ -94,37 +112,27 @@ template class JSONTransport : public Transport { public: using Transport::Transport; + using MessageHandler = typename Transport::MessageHandler; JSONTransport(lldb::IOObjectSP in, lldb::IOObjectSP out) : m_in(in), m_out(out) {} - void Event(const Evt &evt) override { Write(evt); } - void Request(const Req &req) override { Write(req); } - void Response(const Resp &resp) override { Write(resp); } + llvm::Error Event(const Evt &evt) override { return Write(evt); } + llvm::Error Request(const Req &req) override { return Write(req); } + llvm::Error Response(const Resp &resp) override { return Write(resp); } - /// Run registers the transport with the given MainLoop and handles any - /// incoming messages using the given MessageHandler. - llvm::Error - Run(MainLoop &loop, - typename Transport::MessageHandler &handler) override { - llvm::Error error = llvm::Error::success(); + llvm::Expected + RegisterMessageHandler(MainLoop &loop, MessageHandler &handler) override { Status status; - auto read_handle = loop.RegisterReadObject( + MainLoop::ReadHandleUP read_handle = loop.RegisterReadObject( m_in, - std::bind(&JSONTransport::OnRead, this, &error, std::placeholders::_1, + std::bind(&JSONTransport::OnRead, this, std::placeholders::_1, std::ref(handler)), status); if (status.Fail()) { - // This error is only set if the read object handler is invoked, mark it - // as consumed if registration of the handler failed. - llvm::consumeError(std::move(error)); return status.takeError(); } - - status = loop.Run(); - if (status.Fail()) - return status.takeError(); - return error; + return read_handle; } /// Public for testing purposes, otherwise this should be an implementation @@ -134,26 +142,21 @@ class JSONTransport : public Transport { protected: virtual llvm::Expected> Parse() = 0; virtual std::string Encode(const llvm::json::Value &message) = 0; - void Write(const llvm::json::Value &message) { + llvm::Error Write(const llvm::json::Value &message) { this->Logv("<-- {0}", message); std::string output = Encode(message); size_t bytes_written = output.size(); - Status status = m_out->Write(output.data(), bytes_written); - if (status.Fail()) - this->Logv("writing failed: {0}", status.AsCString()); + return m_out->Write(output.data(), bytes_written).takeError(); } llvm::SmallString m_buffer; private: - void OnRead(llvm::Error *err, MainLoopBase &loop, - typename Transport::MessageHandler &handler) { - llvm::ErrorAsOutParameter ErrAsOutParam(err); + void OnRead(MainLoopBase &loop, MessageHandler &handler) { char buf[kReadBufferSize]; size_t num_bytes = sizeof(buf); if (Status status = m_in->Read(buf, num_bytes); status.Fail()) { - *err = status.takeError(); - loop.RequestTermination(); + handler.OnError(loop, status.takeError()); return; } @@ -164,8 +167,7 @@ class JSONTransport : public Transport { if (!m_buffer.empty()) { llvm::Expected> raw_messages = Parse(); if (llvm::Error error = raw_messages.takeError()) { - *err = std::move(error); - loop.RequestTermination(); + handler.OnError(loop, std::move(error)); return; } @@ -174,9 +176,8 @@ class JSONTransport : public Transport { llvm::json::parse::Message>( raw_message); if (!message) { - *err = message.takeError(); - loop.RequestTermination(); - return; + handler.OnError(loop, message.takeError()); + continue; } if (Evt *evt = std::get_if(&*message)) { @@ -198,15 +199,13 @@ class JSONTransport : public Transport { } } + // Check if we reached EOF. if (num_bytes == 0) { - // If we're at EOF and we have unhandled contents in the buffer, return an - // error for the partial message. - if (m_buffer.empty()) - *err = llvm::Error::success(); - else - *err = llvm::make_error( - std::string(m_buffer)); - loop.RequestTermination(); + // EOF reached, but there may still be unhandled contents in the buffer. + if (!m_buffer.empty()) + handler.OnError(loop, llvm::make_error( + std::string(m_buffer.str()))); + handler.OnEOF(); } } diff --git a/lldb/tools/lldb-dap/DAP.cpp b/lldb/tools/lldb-dap/DAP.cpp index 2eced4f78fbd3..a003d1bbaa15b 100644 --- a/lldb/tools/lldb-dap/DAP.cpp +++ b/lldb/tools/lldb-dap/DAP.cpp @@ -267,29 +267,40 @@ void DAP::SendJSON(const llvm::json::Value &json) { void DAP::Send(const Message &message) { if (const protocol::Event *event = std::get_if(&message)) { - transport.Event(*event); + if (llvm::Error err = transport.Event(*event)) { + DAP_LOG_ERROR(log, std::move(err), "({0}) sending event failed", + m_client_name); + return; + } return; } if (const Request *req = std::get_if(&message)) { - transport.Request(*req); + if (llvm::Error err = transport.Request(*req)) { + DAP_LOG_ERROR(log, std::move(err), "({0}) sending request failed", + m_client_name); + return; + } return; } if (const Response *resp = std::get_if(&message)) { // FIXME: After all the requests have migrated from LegacyRequestHandler > // RequestHandler<> this should be handled in RequestHandler<>::operator(). - // If the debugger was interrupted, convert this response into a // 'cancelled' response because we might have a partial result. - if (debugger.InterruptRequested()) { - transport.Response(Response{/*request_seq=*/resp->request_seq, + llvm::Error err = + (debugger.InterruptRequested()) + ? transport.Response({/*request_seq=*/resp->request_seq, /*command=*/resp->command, /*success=*/false, /*message=*/eResponseMessageCancelled, - /*body=*/std::nullopt}); - } else { - transport.Response(*resp); + /*body=*/std::nullopt}) + : transport.Response(*resp); + if (err) { + DAP_LOG_ERROR(log, std::move(err), "({0}) sending response failed", + m_client_name); + return; } return; } @@ -924,10 +935,7 @@ llvm::Error DAP::Disconnect(bool terminateDebuggee) { } SendTerminatedEvent(); - - std::lock_guard guard(m_queue_mutex); - m_disconnecting = true; - + TerminateLoop(); return ToError(error); } @@ -998,21 +1006,66 @@ void DAP::OnResponse(const protocol::Response &response) { m_queue_cv.notify_one(); } -void DAP::TransportHandler(llvm::Error *error) { - llvm::ErrorAsOutParameter ErrAsOutParam(*error); - *error = transport.Run(m_loop, *this); +void DAP::OnError(MainLoopBase &loop, llvm::Error error) { + DAP_LOG_ERROR(log, std::move(error), "({1}) received error: {0}", + m_client_name); + TerminateLoop(/*failed=*/true); +} + +void DAP::OnEOF() { + DAP_LOG(log, "({0}) received EOF", m_client_name); + TerminateLoop(); +} +void DAP::TerminateLoop(bool failed) { std::lock_guard guard(m_queue_mutex); - // Ensure we're marked as disconnecting when the reader exits. + if (m_disconnecting) + return; // Already disconnecting. + + m_error_occurred = failed; m_disconnecting = true; - m_queue_cv.notify_all(); + m_loop.AddPendingCallback( + [](MainLoopBase &loop) { loop.RequestTermination(); }); +} + +void DAP::TransportHandler() { + auto scope_guard = llvm::make_scope_exit([this] { + std::lock_guard guard(m_queue_mutex); + // Ensure we're marked as disconnecting when the reader exits. + m_disconnecting = true; + m_queue_cv.notify_all(); + }); + + auto handle = transport.RegisterMessageHandler(m_loop, *this); + if (!handle) { + DAP_LOG_ERROR(log, handle.takeError(), + "({1}) registering message handler failed: {0}", + m_client_name); + std::lock_guard guard(m_queue_mutex); + m_error_occurred = true; + return; + } + + if (Status status = m_loop.Run(); status.Fail()) { + DAP_LOG_ERROR(log, status.takeError(), "({1}) MainLoop run failed: {0}", + m_client_name); + std::lock_guard guard(m_queue_mutex); + m_error_occurred = true; + return; + } } llvm::Error DAP::Loop() { - llvm::Error error = llvm::Error::success(); - auto thread = std::thread(std::bind(&DAP::TransportHandler, this, &error)); + { + // Reset disconnect flag once we start the loop. + std::lock_guard guard(m_queue_mutex); + m_disconnecting = false; + } - auto cleanup = llvm::make_scope_exit([&]() { + auto thread = std::thread(std::bind(&DAP::TransportHandler, this)); + + auto cleanup = llvm::make_scope_exit([this]() { + // FIXME: Merge these into the MainLoop handler. out.Stop(); err.Stop(); StopEventHandlers(); @@ -1040,7 +1093,11 @@ llvm::Error DAP::Loop() { [](MainLoopBase &loop) { loop.RequestTermination(); }); thread.join(); - return error; + if (m_error_occurred) + return llvm::createStringError(llvm::inconvertibleErrorCode(), + "DAP Loop terminated due to an internal " + "error, see DAP Logs for more information."); + return llvm::Error::success(); } lldb::SBError DAP::WaitForProcessToStop(std::chrono::seconds seconds) { diff --git a/lldb/tools/lldb-dap/DAP.h b/lldb/tools/lldb-dap/DAP.h index 71ec6a7faf7bc..02aebb2c3d23c 100644 --- a/lldb/tools/lldb-dap/DAP.h +++ b/lldb/tools/lldb-dap/DAP.h @@ -429,6 +429,8 @@ struct DAP final : private DAPTransport::MessageHandler { void OnEvent(const protocol::Event &) override; void OnRequest(const protocol::Request &) override; void OnResponse(const protocol::Response &) override; + void OnError(lldb_private::MainLoopBase &loop, llvm::Error error) override; + void OnEOF() override; private: std::vector SetSourceBreakpoints( @@ -436,7 +438,8 @@ struct DAP final : private DAPTransport::MessageHandler { const std::optional> &breakpoints, SourceBreakpointMap &existing_breakpoints); - void TransportHandler(llvm::Error *); + void TransportHandler(); + void TerminateLoop(bool failed = false); /// Registration of request handler. /// @{ @@ -467,6 +470,7 @@ struct DAP final : private DAPTransport::MessageHandler { std::mutex m_queue_mutex; std::condition_variable m_queue_cv; bool m_disconnecting = false; + bool m_error_occurred = false; // Loop for managing reading from the client. lldb_private::MainLoop &m_loop; diff --git a/lldb/tools/lldb-dap/Transport.h b/lldb/tools/lldb-dap/Transport.h index efeb0b9cd6c55..4a9dd76c2303e 100644 --- a/lldb/tools/lldb-dap/Transport.h +++ b/lldb/tools/lldb-dap/Transport.h @@ -34,10 +34,6 @@ class Transport final void Log(llvm::StringRef message) override; - /// Returns the name of this transport client, for example `stdin/stdout` or - /// `client_1`. - llvm::StringRef GetClientName() { return m_client_name; } - private: llvm::StringRef m_client_name; lldb_dap::Log *m_log; diff --git a/lldb/unittests/DAP/DAPTest.cpp b/lldb/unittests/DAP/DAPTest.cpp index 744e6e69a8d33..e6486e710086f 100644 --- a/lldb/unittests/DAP/DAPTest.cpp +++ b/lldb/unittests/DAP/DAPTest.cpp @@ -28,7 +28,7 @@ TEST_F(DAPTest, SendProtocolMessages) { /*pre_init_commands=*/{}, /*client_name=*/"test_client", /*transport=*/*transport, - loop, + /*loop=*/loop, }; dap.Send(Event{/*event=*/"my-event", /*body=*/std::nullopt}); loop.AddPendingCallback( diff --git a/lldb/unittests/DAP/Handler/DisconnectTest.cpp b/lldb/unittests/DAP/Handler/DisconnectTest.cpp index 5b082151680dd..01edeb1d08b31 100644 --- a/lldb/unittests/DAP/Handler/DisconnectTest.cpp +++ b/lldb/unittests/DAP/Handler/DisconnectTest.cpp @@ -28,9 +28,7 @@ class DisconnectRequestHandlerTest : public DAPTestBase {}; TEST_F(DisconnectRequestHandlerTest, DisconnectTriggersTerminated) { DisconnectRequestHandler handler(*dap); - EXPECT_FALSE(dap->disconnecting); ASSERT_THAT_ERROR(handler.Run(std::nullopt), Succeeded()); - EXPECT_TRUE(dap->disconnecting); RunOnce(); EXPECT_THAT(from_dap, testing::Contains(testing::VariantWith(testing::FieldsAre( @@ -47,17 +45,16 @@ TEST_F(DisconnectRequestHandlerTest, DisconnectTriggersTerminateCommands) { DisconnectRequestHandler handler(*dap); - EXPECT_FALSE(dap->disconnecting); dap->configuration.terminateCommands = {"?script print(1)", "script print(2)"}; EXPECT_EQ(dap->target.GetProcess().GetState(), lldb::eStateStopped); ASSERT_THAT_ERROR(handler.Run(std::nullopt), Succeeded()); - EXPECT_TRUE(dap->disconnecting); RunOnce(); EXPECT_THAT(from_dap, testing::Contains(OutputMatcher("Running terminateCommands:\n"))); EXPECT_THAT(from_dap, testing::Contains(OutputMatcher("(lldb) script print(2)\n"))); + EXPECT_THAT(from_dap, testing::Contains(OutputMatcher("1\n"))); EXPECT_THAT(from_dap, testing::Contains(OutputMatcher("2\n"))); EXPECT_THAT(from_dap, testing::Contains(testing::VariantWith(testing::FieldsAre( diff --git a/lldb/unittests/DAP/TestBase.cpp b/lldb/unittests/DAP/TestBase.cpp index 64097d177c4a9..e1a7059f345a1 100644 --- a/lldb/unittests/DAP/TestBase.cpp +++ b/lldb/unittests/DAP/TestBase.cpp @@ -16,6 +16,7 @@ #include "llvm/Support/Error.h" #include "llvm/Testing/Support/Error.h" #include "gtest/gtest.h" +#include #include using namespace llvm; @@ -23,9 +24,27 @@ using namespace lldb; using namespace lldb_dap; using namespace lldb_dap::protocol; using namespace lldb_dap_tests; +using lldb_private::File; +using lldb_private::FileSpec; +using lldb_private::FileSystem; using lldb_private::MainLoop; using lldb_private::Pipe; +Expected +TestTransport::RegisterMessageHandler(MainLoop &loop, MessageHandler &handler) { + Expected dummy_file = FileSystem::Instance().Open( + FileSpec(FileSystem::DEV_NULL), File::eOpenOptionReadWrite); + if (!dummy_file) + return dummy_file.takeError(); + m_dummy_file = std::move(*dummy_file); + lldb_private::Status status; + auto handle = loop.RegisterReadObject( + m_dummy_file, [](lldb_private::MainLoopBase &) {}, status); + if (status.Fail()) + return status.takeError(); + return handle; +} + void DAPTestBase::SetUp() { TransportBase::SetUp(); dap = std::make_unique( @@ -51,7 +70,7 @@ void DAPTestBase::SetUpTestSuite() { } void DAPTestBase::TeatUpTestSuite() { SBDebugger::Terminate(); } -bool DAPTestBase::GetDebuggerSupportsTarget(llvm::StringRef platform) { +bool DAPTestBase::GetDebuggerSupportsTarget(StringRef platform) { EXPECT_TRUE(dap->debugger); lldb::SBStructuredData data = dap->debugger.GetBuildConfiguration() @@ -60,7 +79,7 @@ bool DAPTestBase::GetDebuggerSupportsTarget(llvm::StringRef platform) { for (size_t i = 0; i < data.GetSize(); i++) { char buf[100] = {0}; size_t size = data.GetItemAtIndex(i).GetStringValue(buf, sizeof(buf)); - if (llvm::StringRef(buf, size) == platform) + if (StringRef(buf, size) == platform) return true; } @@ -70,6 +89,24 @@ bool DAPTestBase::GetDebuggerSupportsTarget(llvm::StringRef platform) { void DAPTestBase::CreateDebugger() { dap->debugger = lldb::SBDebugger::Create(); ASSERT_TRUE(dap->debugger); + dap->target = dap->debugger.GetDummyTarget(); + + Expected dev_null = FileSystem::Instance().Open( + FileSpec(FileSystem::DEV_NULL), File::eOpenOptionReadWrite); + ASSERT_THAT_EXPECTED(dev_null, Succeeded()); + lldb::FileSP dev_null_sp = std::move(*dev_null); + + std::FILE *dev_null_stream = dev_null_sp->GetStream(); + ASSERT_THAT_ERROR(dap->ConfigureIO(dev_null_stream, dev_null_stream), + Succeeded()); + + dap->debugger.SetInputFile(dap->in); + auto out_fd = dap->out.GetWriteFileDescriptor(); + ASSERT_THAT_EXPECTED(out_fd, Succeeded()); + dap->debugger.SetOutputFile(lldb::SBFile(*out_fd, "w", false)); + auto err_fd = dap->out.GetWriteFileDescriptor(); + ASSERT_THAT_EXPECTED(err_fd, Succeeded()); + dap->debugger.SetErrorFile(lldb::SBFile(*err_fd, "w", false)); } void DAPTestBase::LoadCore() { diff --git a/lldb/unittests/DAP/TestBase.h b/lldb/unittests/DAP/TestBase.h index 4591c0fc72726..594e1e0a6bbb7 100644 --- a/lldb/unittests/DAP/TestBase.h +++ b/lldb/unittests/DAP/TestBase.h @@ -8,12 +8,19 @@ #include "DAP.h" #include "Protocol/ProtocolBase.h" +#include "lldb/Host/File.h" +#include "lldb/Host/FileSystem.h" #include "lldb/Host/MainLoop.h" #include "lldb/Host/MainLoopBase.h" +#include "lldb/Utility/FileSpec.h" +#include "lldb/lldb-forward.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/FileSystem.h" #include "llvm/Testing/Support/Error.h" #include "gmock/gmock.h" #include "gtest/gtest.h" +#include namespace lldb_dap_tests { @@ -29,27 +36,30 @@ class TestTransport final TestTransport(lldb_private::MainLoop &loop, MessageHandler &handler) : m_loop(loop), m_handler(handler) {} - void Event(const lldb_dap::protocol::Event &e) override { + llvm::Error Event(const lldb_dap::protocol::Event &e) override { m_loop.AddPendingCallback([this, e](lldb_private::MainLoopBase &) { this->m_handler.OnEvent(e); }); + return llvm::Error::success(); } - void Request(const lldb_dap::protocol::Request &r) override { + llvm::Error Request(const lldb_dap::protocol::Request &r) override { m_loop.AddPendingCallback([this, r](lldb_private::MainLoopBase &) { this->m_handler.OnRequest(r); }); + return llvm::Error::success(); } - void Response(const lldb_dap::protocol::Response &r) override { + llvm::Error Response(const lldb_dap::protocol::Response &r) override { m_loop.AddPendingCallback([this, r](lldb_private::MainLoopBase &) { this->m_handler.OnResponse(r); }); + return llvm::Error::success(); } - llvm::Error Run(lldb_private::MainLoop &loop, MessageHandler &) override { - return loop.Run().takeError(); - } + llvm::Expected + RegisterMessageHandler(lldb_private::MainLoop &loop, + MessageHandler &handler) override; void Log(llvm::StringRef message) override { log_messages.emplace_back(message); @@ -60,6 +70,7 @@ class TestTransport final private: lldb_private::MainLoop &m_loop; MessageHandler &m_handler; + lldb::FileSP m_dummy_file; }; /// A base class for tests that need transport configured for communicating DAP @@ -78,12 +89,22 @@ class TransportBase : public testing::Test, void OnEvent(const lldb_dap::protocol::Event &e) override { from_dap.emplace_back(e); } + void OnRequest(const lldb_dap::protocol::Request &r) override { from_dap.emplace_back(r); } + void OnResponse(const lldb_dap::protocol::Response &r) override { from_dap.emplace_back(r); } + + void OnError(lldb_private::MainLoopBase &loop, llvm::Error error) override { + loop.RequestTermination(); + FAIL() << "Error while reading from transport: " + << llvm::toString(std::move(error)); + } + + void OnEOF() override { /* no-op */ } }; /// Matches an "output" event. @@ -113,9 +134,6 @@ class DAPTestBase : public TransportBase { void CreateDebugger(); void LoadCore(); - /// Closes the DAP output pipe and returns the remaining protocol messages in - /// the buffer. - // std::vector DrainOutput(); void RunOnce() { loop.AddPendingCallback( [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); }); diff --git a/lldb/unittests/Host/JSONTransportTest.cpp b/lldb/unittests/Host/JSONTransportTest.cpp index fdfde328c69b7..233c0cc078698 100644 --- a/lldb/unittests/Host/JSONTransportTest.cpp +++ b/lldb/unittests/Host/JSONTransportTest.cpp @@ -160,20 +160,45 @@ class JSONTransportTest : public PipePairTest { class MessageCollector final : public Transport::MessageHandler { public: + MessageCollector(llvm::Error *err = nullptr) : err(err) { + if (err) + consumeError(std::move(*err)); + } std::vector messages; + llvm::Error *err; void OnEvent(const Evt &V) override { messages.emplace_back(V); } void OnRequest(const Req &V) override { messages.emplace_back(V); } void OnResponse(const Resp &V) override { messages.emplace_back(V); } + void OnError(MainLoopBase &loop, llvm::Error error) override { + loop.RequestTermination(); + if (err) + *err = std::move(error); + else + FAIL() << "Error while reading from transport: " + << llvm::toString(std::move(error)); + } + void OnEOF() override { /* no-op */ } }; - /// Run the transport MainLoop and return any messages received. Expected> Run(std::chrono::milliseconds timeout = std::chrono::milliseconds(5000)) { - MessageCollector collector; + return Run(nullptr, timeout); + } + + /// Run the transport MainLoop and return any messages received. + Expected> + Run(llvm::Error *err, + std::chrono::milliseconds timeout = std::chrono::milliseconds(5000)) { + MessageCollector collector(err); loop.AddCallback([](MainLoopBase &loop) { loop.RequestTermination(); }, timeout); - if (auto error = transport->Run(loop, collector)) - return error; + auto handle = transport->RegisterMessageHandler(loop, collector); + if (!handle) + return handle.takeError(); + + if (Status status = loop.Run(); status.Fail()) + return status.takeError(); + return std::move(collector.messages); } @@ -263,7 +288,10 @@ TEST_F(HTTPDelimitedJSONTransportTest, MalformedRequests) { ASSERT_THAT_EXPECTED( input.Write(malformed_header.data(), malformed_header.size()), Succeeded()); - ASSERT_THAT_EXPECTED(Run(), FailedWithMessage("invalid content length: -1")); + llvm::Error err = llvm::Error::success(); + ASSERT_THAT_EXPECTED(Run(&err), Succeeded()); + ASSERT_THAT_ERROR(std::move(err), + FailedWithMessage("invalid content length: -1")); } TEST_F(HTTPDelimitedJSONTransportTest, Read) { @@ -291,8 +319,9 @@ TEST_F(HTTPDelimitedJSONTransportTest, ReadPartialMessage) { std::string part2 = message.substr(split_at); ASSERT_THAT_EXPECTED(input.Write(part1.data(), part1.size()), Succeeded()); - ASSERT_THAT_EXPECTED(Run(/*timeout=*/std::chrono::milliseconds(10)), - HasValue(testing::IsEmpty())); + ASSERT_THAT_EXPECTED( + Run(/*err=*/nullptr, /*timeout=*/std::chrono::milliseconds(10)), + HasValue(testing::IsEmpty())); ASSERT_THAT_EXPECTED(input.Write(part2.data(), part2.size()), Succeeded()); input.CloseWriteFileDescriptor(); @@ -334,7 +363,9 @@ TEST_F(HTTPDelimitedJSONTransportTest, ReaderWithUnhandledData) { ASSERT_THAT_EXPECTED(input.Write(message.data(), message.size() - 1), Succeeded()); input.CloseWriteFileDescriptor(); - ASSERT_THAT_EXPECTED(Run(), Failed()); + Error err = Error::success(); + ASSERT_THAT_EXPECTED(Run(&err), Succeeded()); + ASSERT_THAT_ERROR(std::move(err), Failed()); } TEST_F(HTTPDelimitedJSONTransportTest, NoDataTimeout) { @@ -349,9 +380,9 @@ TEST_F(HTTPDelimitedJSONTransportTest, InvalidTransport) { } TEST_F(HTTPDelimitedJSONTransportTest, Write) { - transport->Request(Req{"foo"}); - transport->Response(Resp{"bar"}); - transport->Event(Evt{"baz"}); + ASSERT_THAT_ERROR(transport->Request(Req{"foo"}), Succeeded()); + ASSERT_THAT_ERROR(transport->Response(Resp{"bar"}), Succeeded()); + ASSERT_THAT_ERROR(transport->Event(Evt{"baz"}), Succeeded()); output.CloseWriteFileDescriptor(); char buf[1024]; Expected bytes_read = @@ -370,8 +401,10 @@ TEST_F(JSONRPCTransportTest, MalformedRequests) { ASSERT_THAT_EXPECTED( input.Write(malformed_header.data(), malformed_header.size()), Succeeded()); - ASSERT_THAT_EXPECTED( - Run(), FailedWithMessage("[1:2, byte=2]: Invalid JSON value (null?)")); + Error err = Error::success(); + ASSERT_THAT_EXPECTED(Run(&err), Succeeded()); + ASSERT_THAT_ERROR(std::move(err), FailedWithMessage(testing::HasSubstr( + "Invalid JSON value"))); } TEST_F(JSONRPCTransportTest, Read) { @@ -416,13 +449,15 @@ TEST_F(JSONRPCTransportTest, ReaderWithUnhandledData) { ASSERT_THAT_EXPECTED(input.Write(message.data(), message.size()), Succeeded()); input.CloseWriteFileDescriptor(); - EXPECT_THAT_EXPECTED(Run(), Failed()); + Error err = Error::success(); + EXPECT_THAT_EXPECTED(Run(&err), Succeeded()); + ASSERT_THAT_ERROR(std::move(err), Failed()); } TEST_F(JSONRPCTransportTest, Write) { - transport->Request(Req{"foo"}); - transport->Response(Resp{"bar"}); - transport->Event(Evt{"baz"}); + ASSERT_THAT_ERROR(transport->Request(Req{"foo"}), Succeeded()); + ASSERT_THAT_ERROR(transport->Response(Resp{"bar"}), Succeeded()); + ASSERT_THAT_ERROR(transport->Event(Evt{"baz"}), Succeeded()); output.CloseWriteFileDescriptor(); char buf[1024]; Expected bytes_read = diff --git a/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp b/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp index 588093edf321a..d10ecfdba2738 100644 --- a/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp +++ b/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp @@ -163,6 +163,12 @@ class ProtocolServerMCPTest : public ::testing::Test { void OnEvent(const Notification &V) override { messages.emplace_back(V); } void OnRequest(const Request &V) override { messages.emplace_back(V); } void OnResponse(const Response &V) override { messages.emplace_back(V); } + void OnError(MainLoopBase &loop, llvm::Error error) override { + loop.RequestTermination(); + FAIL() << "Error while reading from transport: " + << llvm::toString(std::move(error)); + } + void OnEOF() override { /* no-op */ } }; /// Run the transport MainLoop and return any messages received. @@ -171,8 +177,13 @@ class ProtocolServerMCPTest : public ::testing::Test { MessageCollector collector; loop.AddCallback([](MainLoopBase &loop) { loop.RequestTermination(); }, timeout); - if (auto error = m_transport_up->Run(loop, collector)) - return error; + auto handle = m_transport_up->RegisterMessageHandler(loop, collector); + if (!handle) + return handle.takeError(); + + if (Status status = loop.Run(); status.Fail()) + return status.takeError(); + return std::move(collector.messages); } From 5546f66b606c74c57c93f1ec31024057586f59b0 Mon Sep 17 00:00:00 2001 From: John Harrison Date: Mon, 18 Aug 2025 15:32:24 -0700 Subject: [PATCH 4/5] Using gmock for testing message handling. Also refactored the Transport and MessageHandler classes to have more uniform names, which allows us to use `std::visit` for dispatching to the correct method. --- lldb/include/lldb/Host/JSONTransport.h | 60 ++--- lldb/tools/lldb-dap/DAP.cpp | 30 +-- lldb/tools/lldb-dap/DAP.h | 10 +- lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp | 13 + lldb/unittests/DAP/DAPTest.cpp | 6 +- lldb/unittests/DAP/Handler/DisconnectTest.cpp | 19 +- lldb/unittests/DAP/TestBase.cpp | 6 +- lldb/unittests/DAP/TestBase.h | 64 ++--- lldb/unittests/Host/JSONTransportTest.cpp | 244 +++++++++--------- .../ProtocolServer/ProtocolMCPServerTest.cpp | 74 +++--- .../Host/JSONTransportTestUtilities.h | 26 ++ 11 files changed, 274 insertions(+), 278 deletions(-) create mode 100644 lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h diff --git a/lldb/include/lldb/Host/JSONTransport.h b/lldb/include/lldb/Host/JSONTransport.h index adfcf662e2c75..0be60a8f3f96a 100644 --- a/lldb/include/lldb/Host/JSONTransport.h +++ b/lldb/include/lldb/Host/JSONTransport.h @@ -63,31 +63,31 @@ template class Transport { virtual ~Transport() = default; /// Sends an event, a message that does not require a response. - virtual llvm::Error Event(const Evt &) = 0; + virtual llvm::Error Send(const Evt &) = 0; /// Sends a request, a message that expects a response. - virtual llvm::Error Request(const Req &) = 0; + virtual llvm::Error Send(const Req &) = 0; /// Sends a response to a specific request. - virtual llvm::Error Response(const Resp &) = 0; + virtual llvm::Error Send(const Resp &) = 0; /// Implemented to handle incoming messages. (See Run() below). class MessageHandler { public: virtual ~MessageHandler() = default; /// Called when an event is received. - virtual void OnEvent(const Evt &) = 0; + virtual void Received(const Evt &) = 0; /// Called when a request is received. - virtual void OnRequest(const Req &) = 0; + virtual void Received(const Req &) = 0; /// Called when a response is received. - virtual void OnResponse(const Resp &) = 0; + virtual void Received(const Resp &) = 0; /// Called when an error occurs while reading from the transport. /// /// NOTE: This does *NOT* indicate that a specific request failed, but that /// there was an error in the underlying transport. - virtual void OnError(MainLoopBase &, llvm::Error) = 0; + virtual void OnError(llvm::Error) = 0; - /// Called on EOF or disconnect. - virtual void OnEOF() = 0; + /// Called on EOF or client disconnect. + virtual void OnClosed() = 0; }; using MessageHandlerSP = std::shared_ptr; @@ -117,9 +117,9 @@ class JSONTransport : public Transport { JSONTransport(lldb::IOObjectSP in, lldb::IOObjectSP out) : m_in(in), m_out(out) {} - llvm::Error Event(const Evt &evt) override { return Write(evt); } - llvm::Error Request(const Req &req) override { return Write(req); } - llvm::Error Response(const Resp &resp) override { return Write(resp); } + llvm::Error Send(const Evt &evt) override { return Write(evt); } + llvm::Error Send(const Req &req) override { return Write(req); } + llvm::Error Send(const Resp &resp) override { return Write(resp); } llvm::Expected RegisterMessageHandler(MainLoop &loop, MessageHandler &handler) override { @@ -156,7 +156,7 @@ class JSONTransport : public Transport { char buf[kReadBufferSize]; size_t num_bytes = sizeof(buf); if (Status status = m_in->Read(buf, num_bytes); status.Fail()) { - handler.OnError(loop, status.takeError()); + handler.OnError(status.takeError()); return; } @@ -167,7 +167,7 @@ class JSONTransport : public Transport { if (!m_buffer.empty()) { llvm::Expected> raw_messages = Parse(); if (llvm::Error error = raw_messages.takeError()) { - handler.OnError(loop, std::move(error)); + handler.OnError(std::move(error)); return; } @@ -176,26 +176,11 @@ class JSONTransport : public Transport { llvm::json::parse::Message>( raw_message); if (!message) { - handler.OnError(loop, message.takeError()); - continue; - } - - if (Evt *evt = std::get_if(&*message)) { - handler.OnEvent(*evt); - continue; + handler.OnError(message.takeError()); + return; } - if (Req *req = std::get_if(&*message)) { - handler.OnRequest(*req); - continue; - } - - if (Resp *resp = std::get_if(&*message)) { - handler.OnResponse(*resp); - continue; - } - - llvm_unreachable("unknown message type"); + std::visit([&handler](auto &&msg) { handler.Received(msg); }, *message); } } @@ -203,9 +188,9 @@ class JSONTransport : public Transport { if (num_bytes == 0) { // EOF reached, but there may still be unhandled contents in the buffer. if (!m_buffer.empty()) - handler.OnError(loop, llvm::make_error( - std::string(m_buffer.str()))); - handler.OnEOF(); + handler.OnError(llvm::make_error( + std::string(m_buffer.str()))); + handler.OnClosed(); } } @@ -249,10 +234,13 @@ class HTTPDelimitedJSONTransport : public JSONTransport { continue; value = value.trim(); - if (!llvm::to_integer(value, content_length, 10)) + if (!llvm::to_integer(value, content_length, 10)) { + // Clear the buffer to avoid re-parsing this malformed message. + this->m_buffer.clear(); return llvm::createStringError(std::errc::invalid_argument, "invalid content length: %s", value.str().c_str()); + } } // Check if we have enough data. diff --git a/lldb/tools/lldb-dap/DAP.cpp b/lldb/tools/lldb-dap/DAP.cpp index a003d1bbaa15b..e51ed096073fe 100644 --- a/lldb/tools/lldb-dap/DAP.cpp +++ b/lldb/tools/lldb-dap/DAP.cpp @@ -267,20 +267,16 @@ void DAP::SendJSON(const llvm::json::Value &json) { void DAP::Send(const Message &message) { if (const protocol::Event *event = std::get_if(&message)) { - if (llvm::Error err = transport.Event(*event)) { + if (llvm::Error err = transport.Send(*event)) DAP_LOG_ERROR(log, std::move(err), "({0}) sending event failed", m_client_name); - return; - } return; } if (const Request *req = std::get_if(&message)) { - if (llvm::Error err = transport.Request(*req)) { + if (llvm::Error err = transport.Send(*req)) DAP_LOG_ERROR(log, std::move(err), "({0}) sending request failed", m_client_name); - return; - } return; } @@ -291,12 +287,12 @@ void DAP::Send(const Message &message) { // 'cancelled' response because we might have a partial result. llvm::Error err = (debugger.InterruptRequested()) - ? transport.Response({/*request_seq=*/resp->request_seq, - /*command=*/resp->command, - /*success=*/false, - /*message=*/eResponseMessageCancelled, - /*body=*/std::nullopt}) - : transport.Response(*resp); + ? transport.Send({/*request_seq=*/resp->request_seq, + /*command=*/resp->command, + /*success=*/false, + /*message=*/eResponseMessageCancelled, + /*body=*/std::nullopt}) + : transport.Send(*resp); if (err) { DAP_LOG_ERROR(log, std::move(err), "({0}) sending response failed", m_client_name); @@ -964,11 +960,11 @@ static std::optional getArgumentsIfRequest(const Request &req, return args; } -void DAP::OnEvent(const protocol::Event &event) { +void DAP::Received(const protocol::Event &event) { // no-op, no supported events from the client to the server as of DAP v1.68. } -void DAP::OnRequest(const protocol::Request &request) { +void DAP::Received(const protocol::Request &request) { if (request.command == "disconnect") m_disconnecting = true; @@ -998,7 +994,7 @@ void DAP::OnRequest(const protocol::Request &request) { m_queue_cv.notify_one(); } -void DAP::OnResponse(const protocol::Response &response) { +void DAP::Received(const protocol::Response &response) { std::lock_guard guard(m_queue_mutex); DAP_LOG(log, "({0}) queued (command={1} seq={2})", m_client_name, response.command, response.request_seq); @@ -1006,13 +1002,13 @@ void DAP::OnResponse(const protocol::Response &response) { m_queue_cv.notify_one(); } -void DAP::OnError(MainLoopBase &loop, llvm::Error error) { +void DAP::OnError(llvm::Error error) { DAP_LOG_ERROR(log, std::move(error), "({1}) received error: {0}", m_client_name); TerminateLoop(/*failed=*/true); } -void DAP::OnEOF() { +void DAP::OnClosed() { DAP_LOG(log, "({0}) received EOF", m_client_name); TerminateLoop(); } diff --git a/lldb/tools/lldb-dap/DAP.h b/lldb/tools/lldb-dap/DAP.h index 02aebb2c3d23c..0b6373fb80381 100644 --- a/lldb/tools/lldb-dap/DAP.h +++ b/lldb/tools/lldb-dap/DAP.h @@ -426,11 +426,11 @@ struct DAP final : private DAPTransport::MessageHandler { const std::optional> &breakpoints); - void OnEvent(const protocol::Event &) override; - void OnRequest(const protocol::Request &) override; - void OnResponse(const protocol::Response &) override; - void OnError(lldb_private::MainLoopBase &loop, llvm::Error error) override; - void OnEOF() override; + void Received(const protocol::Event &) override; + void Received(const protocol::Request &) override; + void Received(const protocol::Response &) override; + void OnError(llvm::Error) override; + void OnClosed() override; private: std::vector SetSourceBreakpoints( diff --git a/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp b/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp index bc4fee4aa8b8d..9cd9028d879e9 100644 --- a/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp +++ b/lldb/tools/lldb-dap/Protocol/ProtocolBase.cpp @@ -98,6 +98,10 @@ bool fromJSON(json::Value const &Params, Request &R, json::Path P) { return mapRaw(Params, "arguments", R.arguments, P); } +bool operator==(const Request &a, const Request &b) { + return a.seq == b.seq && a.command == b.command && a.arguments == b.arguments; +} + json::Value toJSON(const Response &R) { json::Object Result{{"type", "response"}, {"seq", 0}, @@ -177,6 +181,11 @@ bool fromJSON(json::Value const &Params, Response &R, json::Path P) { mapRaw(Params, "body", R.body, P); } +bool operator==(const Response &a, const Response &b) { + return a.request_seq == b.request_seq && a.command == b.command && + a.success == b.success && a.message == b.message && a.body == b.body; +} + json::Value toJSON(const ErrorMessage &EM) { json::Object Result{{"id", EM.id}, {"format", EM.format}}; @@ -248,6 +257,10 @@ bool fromJSON(json::Value const &Params, Event &E, json::Path P) { return mapRaw(Params, "body", E.body, P); } +bool operator==(const Event &a, const Event &b) { + return a.event == b.event && a.body == b.body; +} + bool fromJSON(const json::Value &Params, Message &PM, json::Path P) { json::ObjectMapper O(Params, P); if (!O) diff --git a/lldb/unittests/DAP/DAPTest.cpp b/lldb/unittests/DAP/DAPTest.cpp index e6486e710086f..d5a9591ad0a43 100644 --- a/lldb/unittests/DAP/DAPTest.cpp +++ b/lldb/unittests/DAP/DAPTest.cpp @@ -10,6 +10,7 @@ #include "Protocol/ProtocolBase.h" #include "TestBase.h" #include "llvm/Testing/Support/Error.h" +#include "gmock/gmock.h" #include "gtest/gtest.h" #include @@ -18,6 +19,7 @@ using namespace lldb; using namespace lldb_dap; using namespace lldb_dap_tests; using namespace lldb_dap::protocol; +using namespace testing; class DAPTest : public TransportBase {}; @@ -33,8 +35,6 @@ TEST_F(DAPTest, SendProtocolMessages) { dap.Send(Event{/*event=*/"my-event", /*body=*/std::nullopt}); loop.AddPendingCallback( [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); }); + EXPECT_CALL(client, Received(IsEvent("my-event", std::nullopt))); ASSERT_THAT_ERROR(dap.Loop(), llvm::Succeeded()); - ASSERT_THAT(from_dap, - ElementsAre(testing::VariantWith(testing::FieldsAre( - /*event=*/"my-event", /*body=*/std::nullopt)))); } diff --git a/lldb/unittests/DAP/Handler/DisconnectTest.cpp b/lldb/unittests/DAP/Handler/DisconnectTest.cpp index 01edeb1d08b31..c6ff1f90b01d5 100644 --- a/lldb/unittests/DAP/Handler/DisconnectTest.cpp +++ b/lldb/unittests/DAP/Handler/DisconnectTest.cpp @@ -23,16 +23,15 @@ using namespace lldb; using namespace lldb_dap; using namespace lldb_dap_tests; using namespace lldb_dap::protocol; +using testing::_; class DisconnectRequestHandlerTest : public DAPTestBase {}; TEST_F(DisconnectRequestHandlerTest, DisconnectTriggersTerminated) { DisconnectRequestHandler handler(*dap); ASSERT_THAT_ERROR(handler.Run(std::nullopt), Succeeded()); + EXPECT_CALL(client, Received(IsEvent("terminated", _))); RunOnce(); - EXPECT_THAT(from_dap, - testing::Contains(testing::VariantWith(testing::FieldsAre( - /*event=*/"terminated", /*body=*/testing::_)))); } TEST_F(DisconnectRequestHandlerTest, DisconnectTriggersTerminateCommands) { @@ -49,14 +48,10 @@ TEST_F(DisconnectRequestHandlerTest, DisconnectTriggersTerminateCommands) { "script print(2)"}; EXPECT_EQ(dap->target.GetProcess().GetState(), lldb::eStateStopped); ASSERT_THAT_ERROR(handler.Run(std::nullopt), Succeeded()); + EXPECT_CALL(client, Received(Output("1\n"))); + EXPECT_CALL(client, Received(Output("2\n"))).Times(2); + EXPECT_CALL(client, Received(Output("(lldb) script print(2)\n"))); + EXPECT_CALL(client, Received(Output("Running terminateCommands:\n"))); + EXPECT_CALL(client, Received(IsEvent("terminated", _))); RunOnce(); - EXPECT_THAT(from_dap, - testing::Contains(OutputMatcher("Running terminateCommands:\n"))); - EXPECT_THAT(from_dap, - testing::Contains(OutputMatcher("(lldb) script print(2)\n"))); - EXPECT_THAT(from_dap, testing::Contains(OutputMatcher("1\n"))); - EXPECT_THAT(from_dap, testing::Contains(OutputMatcher("2\n"))); - EXPECT_THAT(from_dap, - testing::Contains(testing::VariantWith(testing::FieldsAre( - /*event=*/"terminated", /*body=*/testing::_)))); } diff --git a/lldb/unittests/DAP/TestBase.cpp b/lldb/unittests/DAP/TestBase.cpp index e1a7059f345a1..54ac27da694e6 100644 --- a/lldb/unittests/DAP/TestBase.cpp +++ b/lldb/unittests/DAP/TestBase.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "TestBase.h" +#include "DAPLog.h" #include "TestingSupport/TestUtilities.h" #include "lldb/API/SBDefines.h" #include "lldb/API/SBStructuredData.h" @@ -18,6 +19,7 @@ #include "gtest/gtest.h" #include #include +#include using namespace llvm; using namespace lldb; @@ -47,8 +49,10 @@ TestTransport::RegisterMessageHandler(MainLoop &loop, MessageHandler &handler) { void DAPTestBase::SetUp() { TransportBase::SetUp(); + std::error_code EC; + log = std::make_unique("-", EC); dap = std::make_unique( - /*log=*/nullptr, + /*log=*/log.get(), /*default_repl_mode=*/ReplMode::Auto, /*pre_init_commands=*/std::vector(), /*client_name=*/"test_client", diff --git a/lldb/unittests/DAP/TestBase.h b/lldb/unittests/DAP/TestBase.h index 594e1e0a6bbb7..0d70a2bc04082 100644 --- a/lldb/unittests/DAP/TestBase.h +++ b/lldb/unittests/DAP/TestBase.h @@ -8,15 +8,14 @@ #include "DAP.h" #include "Protocol/ProtocolBase.h" -#include "lldb/Host/File.h" -#include "lldb/Host/FileSystem.h" +#include "TestingSupport/Host/JSONTransportTestUtilities.h" #include "lldb/Host/MainLoop.h" #include "lldb/Host/MainLoopBase.h" -#include "lldb/Utility/FileSpec.h" #include "lldb/lldb-forward.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Error.h" #include "llvm/Support/FileSystem.h" +#include "llvm/Support/JSON.h" #include "llvm/Testing/Support/Error.h" #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -36,23 +35,23 @@ class TestTransport final TestTransport(lldb_private::MainLoop &loop, MessageHandler &handler) : m_loop(loop), m_handler(handler) {} - llvm::Error Event(const lldb_dap::protocol::Event &e) override { + llvm::Error Send(const lldb_dap::protocol::Event &e) override { m_loop.AddPendingCallback([this, e](lldb_private::MainLoopBase &) { - this->m_handler.OnEvent(e); + this->m_handler.Received(e); }); return llvm::Error::success(); } - llvm::Error Request(const lldb_dap::protocol::Request &r) override { + llvm::Error Send(const lldb_dap::protocol::Request &r) override { m_loop.AddPendingCallback([this, r](lldb_private::MainLoopBase &) { - this->m_handler.OnRequest(r); + this->m_handler.Received(r); }); return llvm::Error::success(); } - llvm::Error Response(const lldb_dap::protocol::Response &r) override { + llvm::Error Send(const lldb_dap::protocol::Response &r) override { m_loop.AddPendingCallback([this, r](lldb_private::MainLoopBase &) { - this->m_handler.OnResponse(r); + this->m_handler.Received(r); }); return llvm::Error::success(); } @@ -75,49 +74,38 @@ class TestTransport final /// A base class for tests that need transport configured for communicating DAP /// messages. -class TransportBase : public testing::Test, - public TestTransport::MessageHandler { +class TransportBase : public testing::Test { protected: - std::vector from_dap; lldb_private::MainLoop loop; std::unique_ptr transport; + MockMessageHandler + client; void SetUp() override { - transport = std::make_unique(loop, *this); + transport = std::make_unique(loop, client); } - - void OnEvent(const lldb_dap::protocol::Event &e) override { - from_dap.emplace_back(e); - } - - void OnRequest(const lldb_dap::protocol::Request &r) override { - from_dap.emplace_back(r); - } - - void OnResponse(const lldb_dap::protocol::Response &r) override { - from_dap.emplace_back(r); - } - - void OnError(lldb_private::MainLoopBase &loop, llvm::Error error) override { - loop.RequestTermination(); - FAIL() << "Error while reading from transport: " - << llvm::toString(std::move(error)); - } - - void OnEOF() override { /* no-op */ } }; +/// A matcher for a DAP event. +template +inline testing::Matcher +IsEvent(const M1 &m1, const M2 &m2) { + return testing::AllOf(testing::Field(&lldb_dap::protocol::Event::event, m1), + testing::Field(&lldb_dap::protocol::Event::body, m2)); +} + /// Matches an "output" event. -inline auto OutputMatcher(const llvm::StringRef output, - const llvm::StringRef category = "console") { - return testing::VariantWith(testing::FieldsAre( - /*event=*/"output", /*body=*/testing::Optional( - llvm::json::Object{{"category", category}, {"output", output}}))); +inline auto Output(llvm::StringRef o, llvm::StringRef cat = "console") { + return IsEvent("output", + testing::Optional(llvm::json::Value( + llvm::json::Object{{"category", cat}, {"output", o}}))); } /// A base class for tests that interact with a `lldb_dap::DAP` instance. class DAPTestBase : public TransportBase { protected: + std::unique_ptr log; std::unique_ptr dap; std::optional core; std::optional binary; diff --git a/lldb/unittests/Host/JSONTransportTest.cpp b/lldb/unittests/Host/JSONTransportTest.cpp index 233c0cc078698..445674f402252 100644 --- a/lldb/unittests/Host/JSONTransportTest.cpp +++ b/lldb/unittests/Host/JSONTransportTest.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "lldb/Host/JSONTransport.h" +#include "TestingSupport/Host/JSONTransportTestUtilities.h" #include "TestingSupport/Host/PipeTestUtilities.h" #include "lldb/Host/File.h" #include "lldb/Host/MainLoop.h" @@ -28,6 +29,9 @@ using namespace llvm; using namespace lldb_private; +using testing::_; +using testing::HasSubstr; +using testing::InSequence; namespace { @@ -94,16 +98,10 @@ void PrintTo(const Evt &message, std::ostream *os) { } using Message = std::variant; -json::Value toJSON(const Message &T) { - if (const Req *req = std::get_if(&T)) - return toJSON(*req); - if (const Resp *resp = std::get_if(&T)) - return toJSON(*resp); - if (const Evt *evt = std::get_if(&T)) - return toJSON(*evt); - llvm_unreachable("unknown message type"); +json::Value toJSON(const Message &msg) { + return std::visit([](const auto &msg) { return toJSON(msg); }, msg); } -bool fromJSON(const json::Value &V, Message &T, json::Path P) { +bool fromJSON(const json::Value &V, Message &msg, json::Path P) { const json::Object *O = V.getAsObject(); if (!O) { P.report("expected object"); @@ -114,7 +112,7 @@ bool fromJSON(const json::Value &V, Message &T, json::Path P) { if (!fromJSON(V, R, P)) return false; - T = std::move(R); + msg = std::move(R); return true; } if (O->get("resp")) { @@ -122,7 +120,7 @@ bool fromJSON(const json::Value &V, Message &T, json::Path P) { if (!fromJSON(V, R, P)) return false; - T = std::move(R); + msg = std::move(R); return true; } if (O->get("evt")) { @@ -130,7 +128,7 @@ bool fromJSON(const json::Value &V, Message &T, json::Path P) { if (!fromJSON(V, E, P)) return false; - T = std::move(E); + msg = std::move(E); return true; } P.report("unknown message type"); @@ -143,6 +141,7 @@ template class JSONTransportTest : public PipePairTest { protected: + MockMessageHandler message_handler; std::unique_ptr transport; MainLoop loop; @@ -157,49 +156,27 @@ class JSONTransportTest : public PipePairTest { NativeFile::Unowned)); } - class MessageCollector final - : public Transport::MessageHandler { - public: - MessageCollector(llvm::Error *err = nullptr) : err(err) { - if (err) - consumeError(std::move(*err)); - } - std::vector messages; - llvm::Error *err; - void OnEvent(const Evt &V) override { messages.emplace_back(V); } - void OnRequest(const Req &V) override { messages.emplace_back(V); } - void OnResponse(const Resp &V) override { messages.emplace_back(V); } - void OnError(MainLoopBase &loop, llvm::Error error) override { - loop.RequestTermination(); - if (err) - *err = std::move(error); - else - FAIL() << "Error while reading from transport: " - << llvm::toString(std::move(error)); - } - void OnEOF() override { /* no-op */ } - }; - - Expected> - Run(std::chrono::milliseconds timeout = std::chrono::milliseconds(5000)) { - return Run(nullptr, timeout); - } - /// Run the transport MainLoop and return any messages received. - Expected> - Run(llvm::Error *err, + Error + Run(bool close_input = true, std::chrono::milliseconds timeout = std::chrono::milliseconds(5000)) { - MessageCollector collector(err); - loop.AddCallback([](MainLoopBase &loop) { loop.RequestTermination(); }, - timeout); - auto handle = transport->RegisterMessageHandler(loop, collector); + if (close_input) { + input.CloseWriteFileDescriptor(); + EXPECT_CALL(message_handler, OnClosed()).WillOnce([this]() { + loop.RequestTermination(); + }); + } + loop.AddCallback( + [](MainLoopBase &loop) { + loop.RequestTermination(); + FAIL() << "timeout"; + }, + timeout); + auto handle = transport->RegisterMessageHandler(loop, message_handler); if (!handle) return handle.takeError(); - if (Status status = loop.Run(); status.Fail()) - return status.takeError(); - - return std::move(collector.messages); + return loop.Run().takeError(); } template void Write(Ts... args) { @@ -210,11 +187,6 @@ class JSONTransportTest : public PipePairTest { Succeeded()); } - template void WriteAndCloseInput(Ts... args) { - Write(std::forward(args)...); - input.CloseWriteFileDescriptor(); - } - virtual std::string Encode(const json::Value &) = 0; }; @@ -288,28 +260,35 @@ TEST_F(HTTPDelimitedJSONTransportTest, MalformedRequests) { ASSERT_THAT_EXPECTED( input.Write(malformed_header.data(), malformed_header.size()), Succeeded()); - llvm::Error err = llvm::Error::success(); - ASSERT_THAT_EXPECTED(Run(&err), Succeeded()); - ASSERT_THAT_ERROR(std::move(err), - FailedWithMessage("invalid content length: -1")); + + EXPECT_CALL(message_handler, OnError(_)).WillOnce([](llvm::Error err) { + ASSERT_THAT_ERROR(std::move(err), + FailedWithMessage("invalid content length: -1")); + }); + ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(HTTPDelimitedJSONTransportTest, Read) { - WriteAndCloseInput(Req{"foo"}); - ASSERT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre(Req{"foo"}))); + Write(Req{"foo"}); + EXPECT_CALL(message_handler, Received(Req{"foo"})); + ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(HTTPDelimitedJSONTransportTest, ReadMultipleMessagesInSingleWrite) { - WriteAndCloseInput(Message{Req{"one"}}, Message{Resp{"two"}}, - Message{Evt{"three"}}); - EXPECT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre( - Req{"one"}, Resp{"two"}, Evt{"three"}))); + InSequence seq; + Write(Message{Req{"one"}}, Message{Evt{"two"}}, Message{Resp{"three"}}); + EXPECT_CALL(message_handler, Received(Req{"one"})); + EXPECT_CALL(message_handler, Received(Evt{"two"})); + EXPECT_CALL(message_handler, Received(Resp{"three"})); + ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(HTTPDelimitedJSONTransportTest, ReadAcrossMultipleChunks) { - std::string long_str = std::string(2048, 'x'); - WriteAndCloseInput(Req{long_str}); - ASSERT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre(Req{long_str}))); + std::string long_str = std::string( + HTTPDelimitedJSONTransport::kReadBufferSize * 2, 'x'); + Write(Req{long_str}); + EXPECT_CALL(message_handler, Received(Req{long_str})); + ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(HTTPDelimitedJSONTransportTest, ReadPartialMessage) { @@ -318,14 +297,15 @@ TEST_F(HTTPDelimitedJSONTransportTest, ReadPartialMessage) { std::string part1 = message.substr(0, split_at); std::string part2 = message.substr(split_at); - ASSERT_THAT_EXPECTED(input.Write(part1.data(), part1.size()), Succeeded()); - ASSERT_THAT_EXPECTED( - Run(/*err=*/nullptr, /*timeout=*/std::chrono::milliseconds(10)), - HasValue(testing::IsEmpty())); + EXPECT_CALL(message_handler, Received(Req{"foo"})); + ASSERT_THAT_EXPECTED(input.Write(part1.data(), part1.size()), Succeeded()); + loop.AddPendingCallback( + [](MainLoopBase &loop) { loop.RequestTermination(); }); + ASSERT_THAT_ERROR(Run(/*close_stdin=*/false), Succeeded()); ASSERT_THAT_EXPECTED(input.Write(part2.data(), part2.size()), Succeeded()); input.CloseWriteFileDescriptor(); - ASSERT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre(Req{"foo"}))); + ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(HTTPDelimitedJSONTransportTest, ReadWithZeroByteWrites) { @@ -334,23 +314,29 @@ TEST_F(HTTPDelimitedJSONTransportTest, ReadWithZeroByteWrites) { std::string part1 = message.substr(0, split_at); std::string part2 = message.substr(split_at); + EXPECT_CALL(message_handler, Received(Req{"foo"})); + ASSERT_THAT_EXPECTED(input.Write(part1.data(), part1.size()), Succeeded()); - ASSERT_THAT_EXPECTED(Run(/*timeout=*/std::chrono::milliseconds(10)), - HasValue(testing::IsEmpty())); + // Run the main loop once for the initial read. + loop.AddPendingCallback( + [](MainLoopBase &loop) { loop.RequestTermination(); }); + ASSERT_THAT_ERROR(Run(/*close_stdin=*/false), Succeeded()); + + // zero-byte write. ASSERT_THAT_EXPECTED(input.Write(part1.data(), 0), Succeeded()); // zero-byte write. - ASSERT_THAT_EXPECTED(Run(/*timeout=*/std::chrono::milliseconds(10)), - HasValue(testing::IsEmpty())); + loop.AddPendingCallback( + [](MainLoopBase &loop) { loop.RequestTermination(); }); + ASSERT_THAT_ERROR(Run(/*close_stdin=*/false), Succeeded()); + // Write the remaining part of the message. ASSERT_THAT_EXPECTED(input.Write(part2.data(), part2.size()), Succeeded()); - input.CloseWriteFileDescriptor(); - ASSERT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre(Req{"foo"}))); + ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(HTTPDelimitedJSONTransportTest, ReadWithEOF) { - input.CloseWriteFileDescriptor(); - ASSERT_THAT_EXPECTED(Run(), HasValue(testing::IsEmpty())); + ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(HTTPDelimitedJSONTransportTest, ReaderWithUnhandledData) { @@ -359,30 +345,30 @@ TEST_F(HTTPDelimitedJSONTransportTest, ReaderWithUnhandledData) { formatv("Content-Length: {0}\r\nContent-type: text/json\r\n\r\n{1}", json.size(), json) .str(); + + EXPECT_CALL(message_handler, OnError(_)).WillOnce([](llvm::Error err) { + // The error should indicate that there are unhandled contents. + ASSERT_THAT_ERROR(std::move(err), + Failed()); + }); + // Write an incomplete message and close the handle. ASSERT_THAT_EXPECTED(input.Write(message.data(), message.size() - 1), Succeeded()); - input.CloseWriteFileDescriptor(); - Error err = Error::success(); - ASSERT_THAT_EXPECTED(Run(&err), Succeeded()); - ASSERT_THAT_ERROR(std::move(err), Failed()); -} - -TEST_F(HTTPDelimitedJSONTransportTest, NoDataTimeout) { - ASSERT_THAT_EXPECTED(Run(/*timeout=*/std::chrono::milliseconds(10)), - HasValue(testing::IsEmpty())); + ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(HTTPDelimitedJSONTransportTest, InvalidTransport) { transport = std::make_unique(nullptr, nullptr); - ASSERT_THAT_EXPECTED(Run(), FailedWithMessage("IO object is not valid.")); + ASSERT_THAT_ERROR(Run(/*close_input=*/false), + FailedWithMessage("IO object is not valid.")); } TEST_F(HTTPDelimitedJSONTransportTest, Write) { - ASSERT_THAT_ERROR(transport->Request(Req{"foo"}), Succeeded()); - ASSERT_THAT_ERROR(transport->Response(Resp{"bar"}), Succeeded()); - ASSERT_THAT_ERROR(transport->Event(Evt{"baz"}), Succeeded()); + ASSERT_THAT_ERROR(transport->Send(Req{"foo"}), Succeeded()); + ASSERT_THAT_ERROR(transport->Send(Resp{"bar"}), Succeeded()); + ASSERT_THAT_ERROR(transport->Send(Evt{"baz"}), Succeeded()); output.CloseWriteFileDescriptor(); char buf[1024]; Expected bytes_read = @@ -401,26 +387,36 @@ TEST_F(JSONRPCTransportTest, MalformedRequests) { ASSERT_THAT_EXPECTED( input.Write(malformed_header.data(), malformed_header.size()), Succeeded()); - Error err = Error::success(); - ASSERT_THAT_EXPECTED(Run(&err), Succeeded()); - ASSERT_THAT_ERROR(std::move(err), FailedWithMessage(testing::HasSubstr( - "Invalid JSON value"))); + EXPECT_CALL(message_handler, OnError(_)).WillOnce([](llvm::Error err) { + ASSERT_THAT_ERROR(std::move(err), + FailedWithMessage(HasSubstr("Invalid JSON value"))); + }); + ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(JSONRPCTransportTest, Read) { - WriteAndCloseInput(Message{Req{"foo"}}, Message{Resp{"bar"}}, - Message{Evt{"baz"}}); - ASSERT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre( - Req{"foo"}, Resp{"bar"}, Evt{"baz"}))); + Write(Message{Req{"foo"}}); + EXPECT_CALL(message_handler, Received(Req{"foo"})); + ASSERT_THAT_ERROR(Run(), Succeeded()); +} + +TEST_F(JSONRPCTransportTest, ReadMultipleMessagesInSingleWrite) { + InSequence seq; + Write(Message{Req{"one"}}, Message{Evt{"two"}}, Message{Resp{"three"}}); + EXPECT_CALL(message_handler, Received(Req{"one"})); + EXPECT_CALL(message_handler, Received(Evt{"two"})); + EXPECT_CALL(message_handler, Received(Resp{"three"})); + ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(JSONRPCTransportTest, ReadAcrossMultipleChunks) { // Use a string longer than the chunk size to ensure we split the message // across the chunk boundary. std::string long_str = - std::string(JSONTransport::kReadBufferSize + 10, 'x'); - WriteAndCloseInput(Req{long_str}); - ASSERT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre(Req{long_str}))); + std::string(JSONTransport::kReadBufferSize * 2, 'x'); + Write(Req{long_str}); + EXPECT_CALL(message_handler, Received(Req{long_str})); + ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(JSONRPCTransportTest, ReadPartialMessage) { @@ -429,35 +425,39 @@ TEST_F(JSONRPCTransportTest, ReadPartialMessage) { std::string part1 = message.substr(0, 7); std::string part2 = message.substr(7); + EXPECT_CALL(message_handler, Received(Req{"foo"})); + ASSERT_THAT_EXPECTED(input.Write(part1.data(), part1.size()), Succeeded()); - ASSERT_THAT_EXPECTED(Run(std::chrono::milliseconds(10)), - HasValue(testing::IsEmpty())); + loop.AddPendingCallback( + [](MainLoopBase &loop) { loop.RequestTermination(); }); + ASSERT_THAT_ERROR(Run(/*close_input=*/false), Succeeded()); ASSERT_THAT_EXPECTED(input.Write(part2.data(), part2.size()), Succeeded()); input.CloseWriteFileDescriptor(); - ASSERT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre(Req{"foo"}))); + ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(JSONRPCTransportTest, ReadWithEOF) { - input.CloseWriteFileDescriptor(); - ASSERT_THAT_EXPECTED(Run(), HasValue(testing::IsEmpty())); + ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(JSONRPCTransportTest, ReaderWithUnhandledData) { std::string message = R"json({"req": "foo")json"; // Write an incomplete message and close the handle. - ASSERT_THAT_EXPECTED(input.Write(message.data(), message.size()), + ASSERT_THAT_EXPECTED(input.Write(message.data(), message.size() - 1), Succeeded()); - input.CloseWriteFileDescriptor(); - Error err = Error::success(); - EXPECT_THAT_EXPECTED(Run(&err), Succeeded()); - ASSERT_THAT_ERROR(std::move(err), Failed()); + + EXPECT_CALL(message_handler, OnError(_)).WillOnce([](llvm::Error err) { + ASSERT_THAT_ERROR(std::move(err), + Failed()); + }); + ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(JSONRPCTransportTest, Write) { - ASSERT_THAT_ERROR(transport->Request(Req{"foo"}), Succeeded()); - ASSERT_THAT_ERROR(transport->Response(Resp{"bar"}), Succeeded()); - ASSERT_THAT_ERROR(transport->Event(Evt{"baz"}), Succeeded()); + ASSERT_THAT_ERROR(transport->Send(Req{"foo"}), Succeeded()); + ASSERT_THAT_ERROR(transport->Send(Resp{"bar"}), Succeeded()); + ASSERT_THAT_ERROR(transport->Send(Evt{"baz"}), Succeeded()); output.CloseWriteFileDescriptor(); char buf[1024]; Expected bytes_read = @@ -473,12 +473,8 @@ TEST_F(JSONRPCTransportTest, Write) { TEST_F(JSONRPCTransportTest, InvalidTransport) { transport = std::make_unique(nullptr, nullptr); - ASSERT_THAT_EXPECTED(Run(), FailedWithMessage("IO object is not valid.")); -} - -TEST_F(JSONRPCTransportTest, NoDataTimeout) { - ASSERT_THAT_EXPECTED(Run(/*timeout=*/std::chrono::milliseconds(10)), - HasValue(testing::ElementsAre())); + ASSERT_THAT_ERROR(Run(/*close_input=*/false), + FailedWithMessage("IO object is not valid.")); } #endif diff --git a/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp b/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp index d10ecfdba2738..8c3b87f0fdc35 100644 --- a/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp +++ b/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp @@ -8,6 +8,7 @@ #include "Plugins/Platform/MacOSX/PlatformRemoteMacOSX.h" #include "Plugins/Protocol/MCP/ProtocolServerMCP.h" +#include "TestingSupport/Host/JSONTransportTestUtilities.h" #include "TestingSupport/SubsystemRAII.h" #include "lldb/Core/Debugger.h" #include "lldb/Core/ProtocolServer.h" @@ -33,6 +34,7 @@ using namespace llvm; using namespace lldb; using namespace lldb_private; using namespace lldb_protocol::mcp; +using testing::_; namespace { class TestProtocolServerMCP : public lldb_private::mcp::ProtocolServerMCP { @@ -143,6 +145,7 @@ class ProtocolServerMCPTest : public ::testing::Test { std::unique_ptr m_transport_up; std::unique_ptr m_server_up; MainLoop loop; + MockMessageHandler message_handler; static constexpr llvm::StringLiteral k_localhost = "localhost"; @@ -156,35 +159,16 @@ class ProtocolServerMCPTest : public ::testing::Test { EXPECT_THAT_ERROR(m_io_sp->Close().takeError(), Succeeded()); } - class MessageCollector final - : public Transport::MessageHandler { - public: - std::vector messages; - void OnEvent(const Notification &V) override { messages.emplace_back(V); } - void OnRequest(const Request &V) override { messages.emplace_back(V); } - void OnResponse(const Response &V) override { messages.emplace_back(V); } - void OnError(MainLoopBase &loop, llvm::Error error) override { - loop.RequestTermination(); - FAIL() << "Error while reading from transport: " - << llvm::toString(std::move(error)); - } - void OnEOF() override { /* no-op */ } - }; - /// Run the transport MainLoop and return any messages received. - Expected> + llvm::Error Run(std::chrono::milliseconds timeout = std::chrono::milliseconds(100)) { - MessageCollector collector; loop.AddCallback([](MainLoopBase &loop) { loop.RequestTermination(); }, timeout); - auto handle = m_transport_up->RegisterMessageHandler(loop, collector); + auto handle = m_transport_up->RegisterMessageHandler(loop, message_handler); if (!handle) return handle.takeError(); - if (Status status = loop.Run(); status.Fail()) - return status.takeError(); - - return std::move(collector.messages); + return loop.Run().takeError(); } void SetUp() override { @@ -229,19 +213,20 @@ class ProtocolServerMCPTest : public ::testing::Test { TEST_F(ProtocolServerMCPTest, Initialization) { llvm::StringLiteral request = - R"json({"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"lldb-unit","version":"0.1.0"}},"jsonrpc":"2.0","id":0})json"; + R"json({"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"lldb-unit","version":"0.1.0"}},"jsonrpc":"2.0","id":1})json"; llvm::StringLiteral response = - R"json( {"id":0,"jsonrpc":"2.0","result":{"capabilities":{"resources":{"listChanged":false,"subscribe":false},"tools":{"listChanged":true}},"protocolVersion":"2024-11-05","serverInfo":{"name":"lldb-mcp","version":"0.1.0"}}})json"; + R"json({"id":1,"jsonrpc":"2.0","result":{"capabilities":{"resources":{"listChanged":false,"subscribe":false},"tools":{"listChanged":true}},"protocolVersion":"2024-11-05","serverInfo":{"name":"lldb-mcp","version":"0.1.0"}}})json"; ASSERT_THAT_ERROR(Write(request), Succeeded()); - llvm::Expected expected_json = json::parse(response); - ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); - EXPECT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre(*expected_json))); + llvm::Expected expected_resp = json::parse(response); + ASSERT_THAT_EXPECTED(expected_resp, llvm::Succeeded()); + EXPECT_CALL(message_handler, Received(*expected_resp)); + EXPECT_THAT_ERROR(Run(), Succeeded()); } TEST_F(ProtocolServerMCPTest, ToolsList) { llvm::StringLiteral request = - R"json({"method":"tools/list","params":{},"jsonrpc":"2.0","id":1})json"; + R"json({"method":"tools/list","params":{},"jsonrpc":"2.0","id":"one"})json"; ToolDefinition test_tool; test_tool.name = "test"; @@ -258,14 +243,15 @@ TEST_F(ProtocolServerMCPTest, ToolsList) { {"debugger_id", json::Object{{"type", "number"}}}}}, {"required", json::Array{"debugger_id"}}}; Response response; - response.id = 1; + response.id = "one"; response.result = json::Object{ {"tools", json::Array{std::move(test_tool), std::move(lldb_command_tool)}}, }; ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - EXPECT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre(response))); + EXPECT_CALL(message_handler, Received(response)); + EXPECT_THAT_ERROR(Run(), Succeeded()); } TEST_F(ProtocolServerMCPTest, ResourcesList) { @@ -275,9 +261,10 @@ TEST_F(ProtocolServerMCPTest, ResourcesList) { R"json({"id":2,"jsonrpc":"2.0","result":{"resources":[{"description":"description","mimeType":"application/json","name":"name","uri":"lldb://foo/bar"}]}})json"; ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - llvm::Expected expected_json = json::parse(response); - ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); - EXPECT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre(*expected_json))); + llvm::Expected expected_resp = json::parse(response); + ASSERT_THAT_EXPECTED(expected_resp, llvm::Succeeded()); + EXPECT_CALL(message_handler, Received(*expected_resp)); + EXPECT_THAT_ERROR(Run(), Succeeded()); } TEST_F(ProtocolServerMCPTest, ToolsCall) { @@ -287,9 +274,10 @@ TEST_F(ProtocolServerMCPTest, ToolsCall) { R"json({"id":11,"jsonrpc":"2.0","result":{"content":[{"text":"foo","type":"text"}],"isError":false}})json"; ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - llvm::Expected expected_json = json::parse(response); - ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); - ASSERT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre(*expected_json))); + llvm::Expected expected_resp = json::parse(response); + ASSERT_THAT_EXPECTED(expected_resp, llvm::Succeeded()); + EXPECT_CALL(message_handler, Received(*expected_resp)); + EXPECT_THAT_ERROR(Run(), Succeeded()); } TEST_F(ProtocolServerMCPTest, ToolsCallError) { @@ -301,9 +289,10 @@ TEST_F(ProtocolServerMCPTest, ToolsCallError) { R"json({"error":{"code":-32603,"message":"error"},"id":11,"jsonrpc":"2.0"})json"; ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - llvm::Expected expected_json = json::parse(response); - ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); - ASSERT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre(*expected_json))); + llvm::Expected expected_resp = json::parse(response); + ASSERT_THAT_EXPECTED(expected_resp, llvm::Succeeded()); + EXPECT_CALL(message_handler, Received(*expected_resp)); + EXPECT_THAT_ERROR(Run(), Succeeded()); } TEST_F(ProtocolServerMCPTest, ToolsCallFail) { @@ -315,9 +304,10 @@ TEST_F(ProtocolServerMCPTest, ToolsCallFail) { R"json({"id":11,"jsonrpc":"2.0","result":{"content":[{"text":"failed","type":"text"}],"isError":true}})json"; ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - llvm::Expected expected_json = json::parse(response); - ASSERT_THAT_EXPECTED(expected_json, llvm::Succeeded()); - ASSERT_THAT_EXPECTED(Run(), HasValue(testing::ElementsAre(*expected_json))); + llvm::Expected expected_resp = json::parse(response); + ASSERT_THAT_EXPECTED(expected_resp, llvm::Succeeded()); + EXPECT_CALL(message_handler, Received(*expected_resp)); + EXPECT_THAT_ERROR(Run(), Succeeded()); } TEST_F(ProtocolServerMCPTest, NotificationInitialized) { diff --git a/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h b/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h new file mode 100644 index 0000000000000..5a9eb8e59f2b6 --- /dev/null +++ b/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h @@ -0,0 +1,26 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLDB_UNITTESTS_TESTINGSUPPORT_HOST_NATIVEPROCESSTESTUTILS_H +#define LLDB_UNITTESTS_TESTINGSUPPORT_HOST_NATIVEPROCESSTESTUTILS_H + +#include "lldb/Host/JSONTransport.h" +#include "gmock/gmock.h" + +template +class MockMessageHandler final + : public lldb_private::Transport::MessageHandler { +public: + MOCK_METHOD(void, Received, (const Evt &), (override)); + MOCK_METHOD(void, Received, (const Req &), (override)); + MOCK_METHOD(void, Received, (const Resp &), (override)); + MOCK_METHOD(void, OnError, (llvm::Error), (override)); + MOCK_METHOD(void, OnClosed, (), (override)); +}; + +#endif From b7aacb60a0ba63086ceb926733b57278d357558e Mon Sep 17 00:00:00 2001 From: John Harrison Date: Mon, 18 Aug 2025 17:49:14 -0700 Subject: [PATCH 5/5] Add subsystems initialization to tests. --- lldb/unittests/DAP/TestBase.h | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/lldb/unittests/DAP/TestBase.h b/lldb/unittests/DAP/TestBase.h index 0d70a2bc04082..c19eead4e37e7 100644 --- a/lldb/unittests/DAP/TestBase.h +++ b/lldb/unittests/DAP/TestBase.h @@ -9,6 +9,9 @@ #include "DAP.h" #include "Protocol/ProtocolBase.h" #include "TestingSupport/Host/JSONTransportTestUtilities.h" +#include "TestingSupport/SubsystemRAII.h" +#include "lldb/Host/FileSystem.h" +#include "lldb/Host/HostInfo.h" #include "lldb/Host/MainLoop.h" #include "lldb/Host/MainLoopBase.h" #include "lldb/lldb-forward.h" @@ -76,6 +79,8 @@ class TestTransport final /// messages. class TransportBase : public testing::Test { protected: + lldb_private::SubsystemRAII + subsystems; lldb_private::MainLoop loop; std::unique_ptr transport; MockMessageHandler