Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"net/http"
"slices"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -145,6 +146,7 @@ func (c *Client) sendRequest(
ctx context.Context,
method string,
params any,
header http.Header,
) (*json.RawMessage, error) {
if !c.initialized && method != "initialize" {
return nil, fmt.Errorf("client not initialized")
Expand All @@ -157,6 +159,7 @@ func (c *Client) sendRequest(
ID: mcp.NewRequestId(id),
Method: method,
Params: params,
Header: header,
}

response, err := c.transport.SendRequest(ctx, request)
Expand Down Expand Up @@ -198,7 +201,7 @@ func (c *Client) Initialize(
Capabilities: capabilities,
}

response, err := c.sendRequest(ctx, "initialize", params)
response, err := c.sendRequest(ctx, "initialize", params, request.Header)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -243,7 +246,7 @@ func (c *Client) Initialize(
}

func (c *Client) Ping(ctx context.Context) error {
_, err := c.sendRequest(ctx, "ping", nil)
_, err := c.sendRequest(ctx, "ping", nil, nil)
return err
}

Expand Down Expand Up @@ -324,7 +327,7 @@ func (c *Client) ReadResource(
ctx context.Context,
request mcp.ReadResourceRequest,
) (*mcp.ReadResourceResult, error) {
response, err := c.sendRequest(ctx, "resources/read", request.Params)
response, err := c.sendRequest(ctx, "resources/read", request.Params, request.Header)
if err != nil {
return nil, err
}
Expand All @@ -336,15 +339,15 @@ func (c *Client) Subscribe(
ctx context.Context,
request mcp.SubscribeRequest,
) error {
_, err := c.sendRequest(ctx, "resources/subscribe", request.Params)
_, err := c.sendRequest(ctx, "resources/subscribe", request.Params, request.Header)
return err
}

func (c *Client) Unsubscribe(
ctx context.Context,
request mcp.UnsubscribeRequest,
) error {
_, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params)
_, err := c.sendRequest(ctx, "resources/unsubscribe", request.Params, request.Header)
return err
}

Expand Down Expand Up @@ -388,7 +391,7 @@ func (c *Client) GetPrompt(
ctx context.Context,
request mcp.GetPromptRequest,
) (*mcp.GetPromptResult, error) {
response, err := c.sendRequest(ctx, "prompts/get", request.Params)
response, err := c.sendRequest(ctx, "prompts/get", request.Params, request.Header)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -436,7 +439,7 @@ func (c *Client) CallTool(
ctx context.Context,
request mcp.CallToolRequest,
) (*mcp.CallToolResult, error) {
response, err := c.sendRequest(ctx, "tools/call", request.Params)
response, err := c.sendRequest(ctx, "tools/call", request.Params, request.Header)
if err != nil {
return nil, err
}
Expand All @@ -448,15 +451,15 @@ func (c *Client) SetLevel(
ctx context.Context,
request mcp.SetLevelRequest,
) error {
_, err := c.sendRequest(ctx, "logging/setLevel", request.Params)
_, err := c.sendRequest(ctx, "logging/setLevel", request.Params, request.Header)
return err
}

func (c *Client) Complete(
ctx context.Context,
request mcp.CompleteRequest,
) (*mcp.CompleteResult, error) {
response, err := c.sendRequest(ctx, "completion/complete", request.Params)
response, err := c.sendRequest(ctx, "completion/complete", request.Params, request.Header)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -583,7 +586,7 @@ func listByPage[T any](
request mcp.PaginatedRequest,
method string,
) (*T, error) {
response, err := client.sendRequest(ctx, method, request.Params)
response, err := client.sendRequest(ctx, method, request.Params, nil)
if err != nil {
return nil, err
}
Expand Down
2 changes: 2 additions & 0 deletions client/transport/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package transport
import (
"context"
"encoding/json"
"net/http"

"github.com/mark3labs/mcp-go/mcp"
)
Expand Down Expand Up @@ -59,6 +60,7 @@ type JSONRPCRequest struct {
ID mcp.RequestId `json:"id"`
Method string `json:"method"`
Params any `json:"params,omitempty"`
Header http.Header `json:"-"`
}

// JSONRPCResponse represents a JSON-RPC 2.0 response message.
Expand Down
12 changes: 9 additions & 3 deletions client/transport/streamable_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@
ctx, cancel := c.contextAwareOfClientClose(ctx)
defer cancel()

resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream")
resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream", request.Header)
if err != nil {
if errors.Is(err, ErrSessionTerminated) && request.Method == string(mcp.MethodInitialize) {
// If the request is initialize, should not return a SessionTerminated error
Expand Down Expand Up @@ -339,13 +339,19 @@
method string,
body io.Reader,
acceptType string,
header http.Header,
) (resp *http.Response, err error) {
// Create HTTP request
req, err := http.NewRequestWithContext(ctx, method, c.serverURL.String(), body)
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}

// request headers
if header != nil {
req.Header = header
}

Comment on lines +350 to +354
Copy link
Contributor

@coderabbitai coderabbitai bot Aug 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Header assignment may not preserve all default headers.

When a custom header is provided, the code replaces the entire req.Header with the provided header. This could potentially lose important default headers that http.NewRequestWithContext sets up (like Host, User-Agent, etc.).

Consider merging headers instead of replacing:

-	// request headers
-	if header != nil {
-		req.Header = header
-	}
-
+	// Merge custom headers if provided
+	if header != nil {
+		for key, values := range header {
+			for _, value := range values {
+				req.Header.Add(key, value)
+			}
+		}
+	}

Alternatively, if the intention is to allow complete header control, consider cloning the header first to avoid modifying the original:

-	// request headers
-	if header != nil {
-		req.Header = header
-	}
-
+	// Use custom headers if provided
+	if header != nil {
+		req.Header = header.Clone()
+	}
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
// request headers
if header != nil {
req.Header = header
}
// Merge custom headers if provided
if header != nil {
for key, values := range header {
for _, value := range values {
req.Header.Add(key, value)
}
}
}
🤖 Prompt for AI Agents
In client/transport/streamable_http.go around lines 350 to 354, the code
replaces req.Header with the provided header which can discard default headers;
instead merge the provided header into the existing req.Header (iterating keys
and appending values) so you preserve defaults like Host and User-Agent, or if
full control was intended, first clone the provided header and assign the clone
to req.Header to avoid mutating the caller's header; implement one of these two
approaches and ensure you append values rather than overwrite per header key.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@matthisholleville could you address this please?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!

// Set headers
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", acceptType)
Expand Down Expand Up @@ -539,7 +545,7 @@
ctx, cancel := c.contextAwareOfClientClose(ctx)
defer cancel()

resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream")
resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream", nil)
if err != nil {
return fmt.Errorf("failed to send request: %w", err)
}
Expand Down Expand Up @@ -631,7 +637,7 @@
)

func (c *StreamableHTTP) createGETConnectionToServer(ctx context.Context) error {
resp, err := c.sendHTTP(ctx, http.MethodGet, nil, "text/event-stream")
resp, err := c.sendHTTP(ctx, http.MethodGet, nil, "text/event-stream", nil)
if err != nil {
return fmt.Errorf("failed to send request: %w", err)
}
Expand Down Expand Up @@ -746,7 +752,7 @@
ctx, cancel := c.contextAwareOfClientClose(ctx)
defer cancel()

resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(responseBody), "application/json, text/event-stream")

Check failure on line 755 in client/transport/streamable_http.go

View workflow job for this annotation

GitHub Actions / test

not enough arguments in call to c.sendHTTP

Check failure on line 755 in client/transport/streamable_http.go

View workflow job for this annotation

GitHub Actions / lint

not enough arguments in call to c.sendHTTP

Check failure on line 755 in client/transport/streamable_http.go

View workflow job for this annotation

GitHub Actions / lint

not enough arguments in call to c.sendHTTP

Check failure on line 755 in client/transport/streamable_http.go

View workflow job for this annotation

GitHub Actions / lint

not enough arguments in call to c.sendHTTP

Check failure on line 755 in client/transport/streamable_http.go

View workflow job for this annotation

GitHub Actions / lint

not enough arguments in call to c.sendHTTP
if err != nil {
c.logger.Errorf("failed to send response to server: %v", err)
return
Expand Down
53 changes: 53 additions & 0 deletions client/transport/streamable_http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ func startMockStreamableHTTPServer() (string, func()) {
"jsonrpc": "2.0",
"id": request["id"],
"result": request,
"headers": r.Header,
}); err != nil {
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
return
Expand Down Expand Up @@ -122,6 +123,24 @@ func startMockStreamableHTTPServer() (string, func()) {
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
return
}
case "debug/echo_header":
// Check session ID
if r.Header.Get("Mcp-Session-Id") != sessionID {
http.Error(w, "Invalid session ID", http.StatusNotFound)
return
}

// Echo back the request headersas the response result
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(map[string]any{
"jsonrpc": "2.0",
"id": request["id"],
"result": r.Header,
}); err != nil {
http.Error(w, "Failed to encode response", http.StatusInternalServerError)
return
}
}
})

Expand Down Expand Up @@ -215,6 +234,40 @@ func TestStreamableHTTP(t *testing.T) {
}
})

t.Run("SendRequestWithHeader", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

params := map[string]any{
"string": "hello world",
"array": []any{1, 2, 3},
}

request := JSONRPCRequest{
JSONRPC: "2.0",
ID: mcp.NewRequestId(int64(1)),
Method: "debug/echo_header",
Params: params,
Header: http.Header{"X-Test-Header": {"test-header-value"}},
}

// Send the request
response, err := trans.SendRequest(ctx, request)
if err != nil {
t.Fatalf("SendRequest failed: %v", err)
}

// Parse the result to verify echo
var result map[string]any
if err := json.Unmarshal(response.Result, &result); err != nil {
t.Fatalf("Failed to unmarshal result: %v", err)
}

if hdr, ok := result["X-Test-Header"].([]any); !ok || len(hdr) == 0 || hdr[0] != "test-header-value" {
t.Errorf("Expected X-Test-Header to be ['test-header-value'], got %v", result["X-Test-Header"])
}
})

t.Run("SendRequestWithTimeout", func(t *testing.T) {
// Create a context that's already canceled
ctx, cancel := context.WithCancel(context.Background())
Expand Down
Loading