Skip to content

Commit 32f5442

Browse files
authored
Add user agent tracking for minimal usage tracking (#229)
1 parent 8c3f20f commit 32f5442

File tree

10 files changed

+294
-56
lines changed

10 files changed

+294
-56
lines changed

cmd/mcp-grafana/main.go

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,30 +6,15 @@ import (
66
"fmt"
77
"log/slog"
88
"os"
9-
"runtime/debug"
109
"slices"
1110
"strings"
12-
"sync"
1311

1412
"github.com/mark3labs/mcp-go/server"
1513

1614
mcpgrafana "github.com/grafana/mcp-grafana"
1715
"github.com/grafana/mcp-grafana/tools"
1816
)
1917

20-
// version returns the version of the mcp-grafana binary.
21-
// It is populated by the `runtime/debug` package which
22-
// fetches git information from the build directory.
23-
var version = sync.OnceValue(func() string {
24-
// Default version string returned by `runtime/debug` if built
25-
// from the source repository rather than with `go install`.
26-
v := "(devel)"
27-
if bi, ok := debug.ReadBuildInfo(); ok {
28-
v = bi.Main.Version
29-
}
30-
return v
31-
})
32-
3318
func maybeAddTools(s *server.MCPServer, tf func(*server.MCPServer), enabledTools []string, disable bool, category string) {
3419
if !slices.Contains(enabledTools, category) {
3520
slog.Debug("Not enabling tools", "category", category)
@@ -111,7 +96,7 @@ func (dt *disabledTools) addTools(s *server.MCPServer) {
11196
}
11297

11398
func newServer(dt disabledTools) *server.MCPServer {
114-
s := server.NewMCPServer("mcp-grafana", version(), server.WithInstructions(`
99+
s := server.NewMCPServer("mcp-grafana", mcpgrafana.Version(), server.WithInstructions(`
115100
This server provides access to your Grafana instance and the surrounding ecosystem.
116101
117102
Available Capabilities:
@@ -138,14 +123,14 @@ func run(transport, addr, basePath, endpointPath string, logLevel slog.Level, dt
138123
case "stdio":
139124
srv := server.NewStdioServer(s)
140125
srv.SetContextFunc(mcpgrafana.ComposedStdioContextFunc(gc))
141-
slog.Info("Starting Grafana MCP server using stdio transport", "version", version())
126+
slog.Info("Starting Grafana MCP server using stdio transport", "version", mcpgrafana.Version())
142127
return srv.Listen(context.Background(), os.Stdin, os.Stdout)
143128
case "sse":
144129
srv := server.NewSSEServer(s,
145130
server.WithSSEContextFunc(mcpgrafana.ComposedSSEContextFunc(gc)),
146131
server.WithStaticBasePath(basePath),
147132
)
148-
slog.Info("Starting Grafana MCP server using SSE transport", "version", version(), "address", addr, "basePath", basePath)
133+
slog.Info("Starting Grafana MCP server using SSE transport", "version", mcpgrafana.Version(), "address", addr, "basePath", basePath)
149134
if err := srv.Start(addr); err != nil {
150135
return fmt.Errorf("server error: %v", err)
151136
}
@@ -154,7 +139,7 @@ func run(transport, addr, basePath, endpointPath string, logLevel slog.Level, dt
154139
server.WithStateLess(true),
155140
server.WithEndpointPath(endpointPath),
156141
)
157-
slog.Info("Starting Grafana MCP server using StreamableHTTP transport", "version", version(), "address", addr, "endpointPath", endpointPath)
142+
slog.Info("Starting Grafana MCP server using StreamableHTTP transport", "version", mcpgrafana.Version(), "address", addr, "endpointPath", endpointPath)
158143
if err := srv.Start(addr); err != nil {
159144
return fmt.Errorf("server error: %v", err)
160145
}
@@ -188,7 +173,7 @@ func main() {
188173
flag.Parse()
189174

190175
if *showVersion {
191-
fmt.Println(version())
176+
fmt.Println(mcpgrafana.Version())
192177
os.Exit(0)
193178
}
194179

mcpgrafana.go

Lines changed: 76 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ import (
1010
"net/url"
1111
"os"
1212
"reflect"
13+
"runtime/debug"
1314
"strings"
15+
"sync"
1416

1517
"github.com/go-openapi/strfmt"
1618
"github.com/grafana/grafana-openapi-client-go/client"
@@ -59,7 +61,7 @@ type GrafanaConfig struct {
5961
Debug bool
6062

6163
// IncludeArgumentsInSpans enables logging of tool arguments in OpenTelemetry spans.
62-
// This should only be enabled in non-production environments or when you're certain
64+
// This should only be enabled in non-production environments or when you're certain
6365
// the arguments don't contain PII. Defaults to false for safety.
6466
// Note: OpenTelemetry spans are always created for context propagation, but arguments
6567
// are only included when this flag is enabled.
@@ -147,6 +149,65 @@ func (tc *TLSConfig) HTTPTransport(defaultTransport *http.Transport) (http.Round
147149
return transport, nil
148150
}
149151

152+
// UserAgentTransport wraps an http.RoundTripper to add a custom User-Agent header
153+
type UserAgentTransport struct {
154+
rt http.RoundTripper
155+
UserAgent string
156+
}
157+
158+
func (t *UserAgentTransport) RoundTrip(req *http.Request) (*http.Response, error) {
159+
// Clone the request to avoid modifying the original
160+
clonedReq := req.Clone(req.Context())
161+
162+
// Add or update the User-Agent header
163+
if clonedReq.Header.Get("User-Agent") == "" {
164+
clonedReq.Header.Set("User-Agent", t.UserAgent)
165+
}
166+
167+
return t.rt.RoundTrip(clonedReq)
168+
}
169+
170+
// Version returns the version of the mcp-grafana binary.
171+
// It is populated by the `runtime/debug` package which
172+
// fetches git information from the build directory.
173+
var Version = sync.OnceValue(func() string {
174+
// Default version string returned by `runtime/debug` if built
175+
// from the source repository rather than with `go install`.
176+
v := "(devel)"
177+
if bi, ok := debug.ReadBuildInfo(); ok && bi.Main.Version != "" {
178+
v = bi.Main.Version
179+
}
180+
return v
181+
})
182+
183+
// UserAgent returns the user agent string for HTTP requests
184+
func UserAgent() string {
185+
return fmt.Sprintf("mcp-grafana/%s", Version())
186+
}
187+
188+
// NewUserAgentTransport creates a new UserAgentTransport with the specified user agent.
189+
// If no user agent is provided, uses the default UserAgent().
190+
func NewUserAgentTransport(rt http.RoundTripper, userAgent ...string) *UserAgentTransport {
191+
if rt == nil {
192+
rt = http.DefaultTransport
193+
}
194+
195+
ua := UserAgent() // default
196+
if len(userAgent) > 0 {
197+
ua = userAgent[0]
198+
}
199+
200+
return &UserAgentTransport{
201+
rt: rt,
202+
UserAgent: ua,
203+
}
204+
}
205+
206+
// wrapWithUserAgent wraps an http.RoundTripper with user agent tracking
207+
func wrapWithUserAgent(rt http.RoundTripper) http.RoundTripper {
208+
return NewUserAgentTransport(rt)
209+
}
210+
150211
// ExtractGrafanaInfoFromEnv is a StdioContextFunc that extracts Grafana configuration
151212
// from environment variables and injects a configured client into the context.
152213
var ExtractGrafanaInfoFromEnv server.StdioContextFunc = func(ctx context.Context) context.Context {
@@ -281,9 +342,11 @@ func NewGrafanaClient(ctx context.Context, grafanaURL, apiKey string) *client.Gr
281342
transportField := v.FieldByName("Transport")
282343
if transportField.IsValid() && transportField.CanSet() {
283344
if rt, ok := transportField.Interface().(http.RoundTripper); ok {
284-
wrapped := otelhttp.NewTransport(rt)
345+
// Wrap with user agent first, then otel
346+
userAgentWrapped := wrapWithUserAgent(rt)
347+
wrapped := otelhttp.NewTransport(userAgentWrapped)
285348
transportField.Set(reflect.ValueOf(wrapped))
286-
slog.Debug("HTTP tracing enabled for Grafana client")
349+
slog.Debug("HTTP tracing and user agent tracking enabled for Grafana client")
287350
}
288351
}
289352
}
@@ -363,12 +426,15 @@ var ExtractIncidentClientFromEnv server.StdioContextFunc = func(ctx context.Cont
363426
if err != nil {
364427
slog.Error("Failed to create custom transport for incident client, using default", "error", err)
365428
} else {
366-
client.HTTPClient.Transport = transport
367-
slog.Debug("Using custom TLS configuration for incident client",
429+
client.HTTPClient.Transport = wrapWithUserAgent(transport)
430+
slog.Debug("Using custom TLS configuration and user agent for incident client",
368431
"cert_file", tlsConfig.CertFile,
369432
"ca_file", tlsConfig.CAFile,
370433
"skip_verify", tlsConfig.SkipVerify)
371434
}
435+
} else {
436+
// No custom TLS, but still add user agent
437+
client.HTTPClient.Transport = wrapWithUserAgent(http.DefaultTransport)
372438
}
373439

374440
return context.WithValue(ctx, incidentClientKey{}, client)
@@ -395,12 +461,15 @@ var ExtractIncidentClientFromHeaders httpContextFunc = func(ctx context.Context,
395461
if err != nil {
396462
slog.Error("Failed to create custom transport for incident client, using default", "error", err)
397463
} else {
398-
client.HTTPClient.Transport = transport
399-
slog.Debug("Using custom TLS configuration for incident client",
464+
client.HTTPClient.Transport = wrapWithUserAgent(transport)
465+
slog.Debug("Using custom TLS configuration and user agent for incident client",
400466
"cert_file", tlsConfig.CertFile,
401467
"ca_file", tlsConfig.CAFile,
402468
"skip_verify", tlsConfig.SkipVerify)
403469
}
470+
} else {
471+
// No custom TLS, but still add user agent
472+
client.HTTPClient.Transport = wrapWithUserAgent(http.DefaultTransport)
404473
}
405474

406475
return context.WithValue(ctx, incidentClientKey{}, client)

mcpgrafana_test.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ func TestExtractGrafanaInfoFromHeaders(t *testing.T) {
8080
// Explicitly clear environment variables to ensure test isolation
8181
t.Setenv("GRAFANA_URL", "")
8282
t.Setenv("GRAFANA_API_KEY", "")
83-
83+
8484
req, err := http.NewRequest("GET", "http://example.com", nil)
8585
require.NoError(t, err)
8686
ctx := ExtractGrafanaInfoFromHeaders(context.Background(), req)
@@ -239,7 +239,7 @@ func TestToolTracingInstrumentation(t *testing.T) {
239239
type TestParams struct {
240240
Message string `json:"message" jsonschema:"description=Test message"`
241241
}
242-
242+
243243
testHandler := func(ctx context.Context, args TestParams) (string, error) {
244244
return "Hello " + args.Message, nil
245245
}
@@ -295,7 +295,7 @@ func TestToolTracingInstrumentation(t *testing.T) {
295295
type TestParams struct {
296296
ShouldFail bool `json:"shouldFail" jsonschema:"description=Whether to fail"`
297297
}
298-
298+
299299
testHandler := func(ctx context.Context, args TestParams) (string, error) {
300300
if args.ShouldFail {
301301
return "", assert.AnError
@@ -358,7 +358,7 @@ func TestToolTracingInstrumentation(t *testing.T) {
358358
type TestParams struct {
359359
Message string `json:"message" jsonschema:"description=Test message"`
360360
}
361-
361+
362362
testHandler := func(ctx context.Context, args TestParams) (string, error) {
363363
return "processed", nil
364364
}
@@ -406,7 +406,7 @@ func TestToolTracingInstrumentation(t *testing.T) {
406406
type TestParams struct {
407407
SensitiveData string `json:"sensitiveData" jsonschema:"description=Potentially sensitive data"`
408408
}
409-
409+
410410
testHandler := func(ctx context.Context, args TestParams) (string, error) {
411411
return "processed", nil
412412
}
@@ -451,7 +451,7 @@ func TestToolTracingInstrumentation(t *testing.T) {
451451
attributes := span.Attributes()
452452
assertHasAttribute(t, attributes, "mcp.tool.name", "sensitive_tool")
453453
assertHasAttribute(t, attributes, "mcp.tool.description", "A tool with sensitive data")
454-
454+
455455
// Verify arguments are NOT present
456456
for _, attr := range attributes {
457457
assert.NotEqual(t, "mcp.tool.arguments", string(attr.Key), "Arguments should not be logged by default for PII safety")
@@ -466,7 +466,7 @@ func TestToolTracingInstrumentation(t *testing.T) {
466466
type TestParams struct {
467467
SafeData string `json:"safeData" jsonschema:"description=Non-sensitive data"`
468468
}
469-
469+
470470
testHandler := func(ctx context.Context, args TestParams) (string, error) {
471471
return "processed", nil
472472
}
@@ -531,7 +531,7 @@ func TestHTTPTracingConfiguration(t *testing.T) {
531531

532532
t.Run("tracing works gracefully without OpenTelemetry configured", func(t *testing.T) {
533533
// No OpenTelemetry tracer provider configured
534-
534+
535535
// Create context (tracing always enabled for context propagation)
536536
config := GrafanaConfig{}
537537
ctx := WithGrafanaConfig(context.Background(), config)

0 commit comments

Comments
 (0)