diff --git a/README.md b/README.md index ba05d8431..d75373879 100644 --- a/README.md +++ b/README.md @@ -285,7 +285,7 @@ The MongoDB MCP Server can be configured using multiple methods, with the follow | `connectionString` | MongoDB connection string for direct database connections. Optional, if not set, you'll need to call the `connect` tool before interacting with MongoDB data. | | `logPath` | Folder to store logs. | | `disabledTools` | An array of tool names, operation types, and/or categories of tools that will be disabled. | -| `readOnly` | When set to true, only allows read and metadata operation types, disabling create/update/delete operations. | +| `readOnly` | When set to true, only allows read, connect, and metadata operation types, disabling create/update/delete operations. | | `indexCheck` | When set to true, enforces that query operations must use an index, rejecting queries that perform a collection scan. | | `telemetry` | When set to disabled, disables telemetry collection. | @@ -318,10 +318,11 @@ Operation types: - `delete` - Tools that delete resources, such as delete document, drop collection, etc. - `read` - Tools that read resources, such as find, aggregate, list clusters, etc. - `metadata` - Tools that read metadata, such as list databases, list collections, collection schema, etc. +- `connect` - Tools that allow you to connect or switch the connection to a MongoDB instance. If this is disabled, you will need to provide a connection string through the config when starting the server. #### Read-Only Mode -The `readOnly` configuration option allows you to restrict the MCP server to only use tools with "read" and "metadata" operation types. When enabled, all tools that have "create", "update" or "delete" operation types will not be registered with the server. +The `readOnly` configuration option allows you to restrict the MCP server to only use tools with "read", "connect", and "metadata" operation types. When enabled, all tools that have "create", "update" or "delete" operation types will not be registered with the server. This is useful for scenarios where you want to provide access to MongoDB data for analysis without allowing any modifications to the data or infrastructure. diff --git a/src/server.ts b/src/server.ts index 31a99ded7..c32dc367d 100644 --- a/src/server.ts +++ b/src/server.ts @@ -12,6 +12,7 @@ import { type ServerCommand } from "./telemetry/types.js"; import { CallToolRequestSchema, CallToolResult } from "@modelcontextprotocol/sdk/types.js"; import assert from "assert"; import { detectContainerEnv } from "./common/container.js"; +import { ToolBase } from "./tools/tool.js"; export interface ServerOptions { session: Session; @@ -22,9 +23,10 @@ export interface ServerOptions { export class Server { public readonly session: Session; - private readonly mcpServer: McpServer; + public readonly mcpServer: McpServer; private readonly telemetry: Telemetry; public readonly userConfig: UserConfig; + public readonly tools: ToolBase[] = []; private readonly startTime: number; constructor({ session, mcpServer, userConfig, telemetry }: ServerOptions) { @@ -141,8 +143,11 @@ export class Server { } private registerTools() { - for (const tool of [...AtlasTools, ...MongoDbTools]) { - new tool(this.session, this.userConfig, this.telemetry).register(this.mcpServer); + for (const toolConstructor of [...AtlasTools, ...MongoDbTools]) { + const tool = new toolConstructor(this.session, this.userConfig, this.telemetry); + if (tool.register(this)) { + this.tools.push(tool); + } } } diff --git a/src/tools/atlas/atlasTool.ts b/src/tools/atlas/atlasTool.ts index 2b93a5ec4..eb7c2f1f4 100644 --- a/src/tools/atlas/atlasTool.ts +++ b/src/tools/atlas/atlasTool.ts @@ -6,7 +6,7 @@ import { z } from "zod"; import { ApiClientError } from "../../common/atlas/apiClientError.js"; export abstract class AtlasToolBase extends ToolBase { - protected category: ToolCategory = "atlas"; + public category: ToolCategory = "atlas"; protected verifyAllowed(): boolean { if (!this.config.apiClientId || !this.config.apiClientSecret) { @@ -29,7 +29,7 @@ export abstract class AtlasToolBase extends ToolBase { type: "text", text: `Unable to authenticate with MongoDB Atlas, API error: ${error.message} -Hint: Your API credentials may be invalid, expired or lack permissions. +Hint: Your API credentials may be invalid, expired or lack permissions. Please check your Atlas API credentials and ensure they have the appropriate permissions. For more information on setting up API keys, visit: https://www.mongodb.com/docs/atlas/configure-api-access/`, }, @@ -44,7 +44,7 @@ For more information on setting up API keys, visit: https://www.mongodb.com/docs { type: "text", text: `Received a Forbidden API Error: ${error.message} - + You don't have sufficient permissions to perform this action in MongoDB Atlas Please ensure your API key has the necessary roles assigned. For more information on Atlas API access roles, visit: https://www.mongodb.com/docs/atlas/api/service-accounts-overview/`, diff --git a/src/tools/atlas/metadata/connectCluster.ts b/src/tools/atlas/connect/connectCluster.ts similarity index 98% rename from src/tools/atlas/metadata/connectCluster.ts rename to src/tools/atlas/connect/connectCluster.ts index a65913a61..31113e822 100644 --- a/src/tools/atlas/metadata/connectCluster.ts +++ b/src/tools/atlas/connect/connectCluster.ts @@ -13,9 +13,9 @@ function sleep(ms: number): Promise { } export class ConnectClusterTool extends AtlasToolBase { - protected name = "atlas-connect-cluster"; + public name = "atlas-connect-cluster"; protected description = "Connect to MongoDB Atlas cluster"; - protected operationType: OperationType = "metadata"; + public operationType: OperationType = "connect"; protected argsShape = { projectId: z.string().describe("Atlas project ID"), clusterName: z.string().describe("Atlas cluster name"), diff --git a/src/tools/atlas/create/createAccessList.ts b/src/tools/atlas/create/createAccessList.ts index 1c38279a7..4941b1e8c 100644 --- a/src/tools/atlas/create/createAccessList.ts +++ b/src/tools/atlas/create/createAccessList.ts @@ -6,9 +6,9 @@ import { ToolArgs, OperationType } from "../../tool.js"; const DEFAULT_COMMENT = "Added by Atlas MCP"; export class CreateAccessListTool extends AtlasToolBase { - protected name = "atlas-create-access-list"; + public name = "atlas-create-access-list"; protected description = "Allow Ip/CIDR ranges to access your MongoDB Atlas clusters."; - protected operationType: OperationType = "create"; + public operationType: OperationType = "create"; protected argsShape = { projectId: z.string().describe("Atlas project ID"), ipAddresses: z diff --git a/src/tools/atlas/create/createDBUser.ts b/src/tools/atlas/create/createDBUser.ts index a8266a0a1..fef9d513d 100644 --- a/src/tools/atlas/create/createDBUser.ts +++ b/src/tools/atlas/create/createDBUser.ts @@ -6,9 +6,9 @@ import { CloudDatabaseUser, DatabaseUserRole } from "../../../common/atlas/opena import { generateSecurePassword } from "../../../common/atlas/generatePassword.js"; export class CreateDBUserTool extends AtlasToolBase { - protected name = "atlas-create-db-user"; + public name = "atlas-create-db-user"; protected description = "Create an MongoDB Atlas database user"; - protected operationType: OperationType = "create"; + public operationType: OperationType = "create"; protected argsShape = { projectId: z.string().describe("Atlas project ID"), username: z.string().describe("Username for the new user"), diff --git a/src/tools/atlas/create/createFreeCluster.ts b/src/tools/atlas/create/createFreeCluster.ts index 2d93ae801..ed04409b0 100644 --- a/src/tools/atlas/create/createFreeCluster.ts +++ b/src/tools/atlas/create/createFreeCluster.ts @@ -5,9 +5,9 @@ import { ToolArgs, OperationType } from "../../tool.js"; import { ClusterDescription20240805 } from "../../../common/atlas/openapi.js"; export class CreateFreeClusterTool extends AtlasToolBase { - protected name = "atlas-create-free-cluster"; + public name = "atlas-create-free-cluster"; protected description = "Create a free MongoDB Atlas cluster"; - protected operationType: OperationType = "create"; + public operationType: OperationType = "create"; protected argsShape = { projectId: z.string().describe("Atlas project ID to create the cluster in"), name: z.string().describe("Name of the cluster"), diff --git a/src/tools/atlas/create/createProject.ts b/src/tools/atlas/create/createProject.ts index cdf71b9c6..29bff3f6c 100644 --- a/src/tools/atlas/create/createProject.ts +++ b/src/tools/atlas/create/createProject.ts @@ -5,9 +5,9 @@ import { ToolArgs, OperationType } from "../../tool.js"; import { Group } from "../../../common/atlas/openapi.js"; export class CreateProjectTool extends AtlasToolBase { - protected name = "atlas-create-project"; + public name = "atlas-create-project"; protected description = "Create a MongoDB Atlas project"; - protected operationType: OperationType = "create"; + public operationType: OperationType = "create"; protected argsShape = { projectName: z.string().optional().describe("Name for the new project"), organizationId: z.string().optional().describe("Organization ID for the new project"), diff --git a/src/tools/atlas/read/inspectAccessList.ts b/src/tools/atlas/read/inspectAccessList.ts index 94c852280..13e027c95 100644 --- a/src/tools/atlas/read/inspectAccessList.ts +++ b/src/tools/atlas/read/inspectAccessList.ts @@ -4,9 +4,9 @@ import { AtlasToolBase } from "../atlasTool.js"; import { ToolArgs, OperationType } from "../../tool.js"; export class InspectAccessListTool extends AtlasToolBase { - protected name = "atlas-inspect-access-list"; + public name = "atlas-inspect-access-list"; protected description = "Inspect Ip/CIDR ranges with access to your MongoDB Atlas clusters."; - protected operationType: OperationType = "read"; + public operationType: OperationType = "read"; protected argsShape = { projectId: z.string().describe("Atlas project ID"), }; diff --git a/src/tools/atlas/read/inspectCluster.ts b/src/tools/atlas/read/inspectCluster.ts index c73c1b76f..a4209fd5f 100644 --- a/src/tools/atlas/read/inspectCluster.ts +++ b/src/tools/atlas/read/inspectCluster.ts @@ -5,9 +5,9 @@ import { ToolArgs, OperationType } from "../../tool.js"; import { Cluster, inspectCluster } from "../../../common/atlas/cluster.js"; export class InspectClusterTool extends AtlasToolBase { - protected name = "atlas-inspect-cluster"; + public name = "atlas-inspect-cluster"; protected description = "Inspect MongoDB Atlas cluster"; - protected operationType: OperationType = "read"; + public operationType: OperationType = "read"; protected argsShape = { projectId: z.string().describe("Atlas project ID"), clusterName: z.string().describe("Atlas cluster name"), diff --git a/src/tools/atlas/read/listAlerts.ts b/src/tools/atlas/read/listAlerts.ts index bbbf6f142..dcf56a63d 100644 --- a/src/tools/atlas/read/listAlerts.ts +++ b/src/tools/atlas/read/listAlerts.ts @@ -4,9 +4,9 @@ import { AtlasToolBase } from "../atlasTool.js"; import { ToolArgs, OperationType } from "../../tool.js"; export class ListAlertsTool extends AtlasToolBase { - protected name = "atlas-list-alerts"; + public name = "atlas-list-alerts"; protected description = "List MongoDB Atlas alerts"; - protected operationType: OperationType = "read"; + public operationType: OperationType = "read"; protected argsShape = { projectId: z.string().describe("Atlas project ID to list alerts for"), }; diff --git a/src/tools/atlas/read/listClusters.ts b/src/tools/atlas/read/listClusters.ts index a8af8828a..99c26fe62 100644 --- a/src/tools/atlas/read/listClusters.ts +++ b/src/tools/atlas/read/listClusters.ts @@ -11,9 +11,9 @@ import { import { formatCluster, formatFlexCluster } from "../../../common/atlas/cluster.js"; export class ListClustersTool extends AtlasToolBase { - protected name = "atlas-list-clusters"; + public name = "atlas-list-clusters"; protected description = "List MongoDB Atlas clusters"; - protected operationType: OperationType = "read"; + public operationType: OperationType = "read"; protected argsShape = { projectId: z.string().describe("Atlas project ID to filter clusters").optional(), }; diff --git a/src/tools/atlas/read/listDBUsers.ts b/src/tools/atlas/read/listDBUsers.ts index 7650cbf0c..57344d652 100644 --- a/src/tools/atlas/read/listDBUsers.ts +++ b/src/tools/atlas/read/listDBUsers.ts @@ -5,9 +5,9 @@ import { ToolArgs, OperationType } from "../../tool.js"; import { DatabaseUserRole, UserScope } from "../../../common/atlas/openapi.js"; export class ListDBUsersTool extends AtlasToolBase { - protected name = "atlas-list-db-users"; + public name = "atlas-list-db-users"; protected description = "List MongoDB Atlas database users"; - protected operationType: OperationType = "read"; + public operationType: OperationType = "read"; protected argsShape = { projectId: z.string().describe("Atlas project ID to filter DB users"), }; diff --git a/src/tools/atlas/read/listOrgs.ts b/src/tools/atlas/read/listOrgs.ts index c55738d76..66b4c9684 100644 --- a/src/tools/atlas/read/listOrgs.ts +++ b/src/tools/atlas/read/listOrgs.ts @@ -3,9 +3,9 @@ import { AtlasToolBase } from "../atlasTool.js"; import { OperationType } from "../../tool.js"; export class ListOrganizationsTool extends AtlasToolBase { - protected name = "atlas-list-orgs"; + public name = "atlas-list-orgs"; protected description = "List MongoDB Atlas organizations"; - protected operationType: OperationType = "read"; + public operationType: OperationType = "read"; protected argsShape = {}; protected async execute(): Promise { diff --git a/src/tools/atlas/read/listProjects.ts b/src/tools/atlas/read/listProjects.ts index 1a9ab523f..e8fc02491 100644 --- a/src/tools/atlas/read/listProjects.ts +++ b/src/tools/atlas/read/listProjects.ts @@ -5,9 +5,9 @@ import { z } from "zod"; import { ToolArgs } from "../../tool.js"; export class ListProjectsTool extends AtlasToolBase { - protected name = "atlas-list-projects"; + public name = "atlas-list-projects"; protected description = "List MongoDB Atlas projects"; - protected operationType: OperationType = "read"; + public operationType: OperationType = "read"; protected argsShape = { orgId: z.string().describe("Atlas organization ID to filter projects").optional(), }; diff --git a/src/tools/atlas/tools.ts b/src/tools/atlas/tools.ts index 9c27740d9..c43b88ef7 100644 --- a/src/tools/atlas/tools.ts +++ b/src/tools/atlas/tools.ts @@ -8,7 +8,7 @@ import { ListDBUsersTool } from "./read/listDBUsers.js"; import { CreateDBUserTool } from "./create/createDBUser.js"; import { CreateProjectTool } from "./create/createProject.js"; import { ListOrganizationsTool } from "./read/listOrgs.js"; -import { ConnectClusterTool } from "./metadata/connectCluster.js"; +import { ConnectClusterTool } from "./connect/connectCluster.js"; import { ListAlertsTool } from "./read/listAlerts.js"; export const AtlasTools = [ diff --git a/src/tools/mongodb/metadata/connect.ts b/src/tools/mongodb/connect/connect.ts similarity index 90% rename from src/tools/mongodb/metadata/connect.ts rename to src/tools/mongodb/connect/connect.ts index 578220014..e8de93339 100644 --- a/src/tools/mongodb/metadata/connect.ts +++ b/src/tools/mongodb/connect/connect.ts @@ -2,11 +2,11 @@ import { z } from "zod"; import { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; import { MongoDBToolBase } from "../mongodbTool.js"; import { ToolArgs, OperationType } from "../../tool.js"; -import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; import assert from "assert"; import { UserConfig } from "../../../config.js"; import { Telemetry } from "../../../telemetry/telemetry.js"; import { Session } from "../../../session.js"; +import { Server } from "../../../server.js"; const disconnectedSchema = z .object({ @@ -33,7 +33,7 @@ const connectedDescription = const disconnectedDescription = "Connect to a MongoDB instance"; export class ConnectTool extends MongoDBToolBase { - protected name: typeof connectedName | typeof disconnectedName = disconnectedName; + public name: typeof connectedName | typeof disconnectedName = disconnectedName; protected description: typeof connectedDescription | typeof disconnectedDescription = disconnectedDescription; // Here the default is empty just to trigger registration, but we're going to override it with the correct @@ -42,7 +42,7 @@ export class ConnectTool extends MongoDBToolBase { connectionString: z.string().optional(), }; - protected operationType: OperationType = "metadata"; + public operationType: OperationType = "connect"; constructor(session: Session, config: UserConfig, telemetry: Telemetry) { super(session, config, telemetry); @@ -72,10 +72,13 @@ export class ConnectTool extends MongoDBToolBase { }; } - public register(server: McpServer): void { - super.register(server); + public register(server: Server): boolean { + if (super.register(server)) { + this.updateMetadata(); + return true; + } - this.updateMetadata(); + return false; } private updateMetadata(): void { diff --git a/src/tools/mongodb/create/createCollection.ts b/src/tools/mongodb/create/createCollection.ts index 27eaa9f59..0b1c65a7b 100644 --- a/src/tools/mongodb/create/createCollection.ts +++ b/src/tools/mongodb/create/createCollection.ts @@ -3,12 +3,12 @@ import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js"; import { OperationType, ToolArgs } from "../../tool.js"; export class CreateCollectionTool extends MongoDBToolBase { - protected name = "create-collection"; + public name = "create-collection"; protected description = "Creates a new collection in a database. If the database doesn't exist, it will be created automatically."; protected argsShape = DbOperationArgs; - protected operationType: OperationType = "create"; + public operationType: OperationType = "create"; protected async execute({ collection, database }: ToolArgs): Promise { const provider = await this.ensureConnected(); diff --git a/src/tools/mongodb/create/createIndex.ts b/src/tools/mongodb/create/createIndex.ts index beffaf864..8e393f04a 100644 --- a/src/tools/mongodb/create/createIndex.ts +++ b/src/tools/mongodb/create/createIndex.ts @@ -5,7 +5,7 @@ import { ToolArgs, OperationType } from "../../tool.js"; import { IndexDirection } from "mongodb"; export class CreateIndexTool extends MongoDBToolBase { - protected name = "create-index"; + public name = "create-index"; protected description = "Create an index for a collection"; protected argsShape = { ...DbOperationArgs, @@ -13,7 +13,7 @@ export class CreateIndexTool extends MongoDBToolBase { name: z.string().optional().describe("The name of the index"), }; - protected operationType: OperationType = "create"; + public operationType: OperationType = "create"; protected async execute({ database, diff --git a/src/tools/mongodb/create/insertMany.ts b/src/tools/mongodb/create/insertMany.ts index f28d79d5d..4744e344a 100644 --- a/src/tools/mongodb/create/insertMany.ts +++ b/src/tools/mongodb/create/insertMany.ts @@ -4,7 +4,7 @@ import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js"; import { ToolArgs, OperationType } from "../../tool.js"; export class InsertManyTool extends MongoDBToolBase { - protected name = "insert-many"; + public name = "insert-many"; protected description = "Insert an array of documents into a MongoDB collection"; protected argsShape = { ...DbOperationArgs, @@ -14,7 +14,7 @@ export class InsertManyTool extends MongoDBToolBase { "The array of documents to insert, matching the syntax of the document argument of db.collection.insertMany()" ), }; - protected operationType: OperationType = "create"; + public operationType: OperationType = "create"; protected async execute({ database, diff --git a/src/tools/mongodb/delete/deleteMany.ts b/src/tools/mongodb/delete/deleteMany.ts index 0257d1676..aa1355127 100644 --- a/src/tools/mongodb/delete/deleteMany.ts +++ b/src/tools/mongodb/delete/deleteMany.ts @@ -5,7 +5,7 @@ import { ToolArgs, OperationType } from "../../tool.js"; import { checkIndexUsage } from "../../../helpers/indexCheck.js"; export class DeleteManyTool extends MongoDBToolBase { - protected name = "delete-many"; + public name = "delete-many"; protected description = "Removes all documents that match the filter from a MongoDB collection"; protected argsShape = { ...DbOperationArgs, @@ -16,7 +16,7 @@ export class DeleteManyTool extends MongoDBToolBase { "The query filter, specifying the deletion criteria. Matches the syntax of the filter argument of db.collection.deleteMany()" ), }; - protected operationType: OperationType = "delete"; + public operationType: OperationType = "delete"; protected async execute({ database, diff --git a/src/tools/mongodb/delete/dropCollection.ts b/src/tools/mongodb/delete/dropCollection.ts index ac914f75d..f555df048 100644 --- a/src/tools/mongodb/delete/dropCollection.ts +++ b/src/tools/mongodb/delete/dropCollection.ts @@ -3,13 +3,13 @@ import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js"; import { ToolArgs, OperationType } from "../../tool.js"; export class DropCollectionTool extends MongoDBToolBase { - protected name = "drop-collection"; + public name = "drop-collection"; protected description = "Removes a collection or view from the database. The method also removes any indexes associated with the dropped collection."; protected argsShape = { ...DbOperationArgs, }; - protected operationType: OperationType = "delete"; + public operationType: OperationType = "delete"; protected async execute({ database, collection }: ToolArgs): Promise { const provider = await this.ensureConnected(); diff --git a/src/tools/mongodb/delete/dropDatabase.ts b/src/tools/mongodb/delete/dropDatabase.ts index b10862b20..019672659 100644 --- a/src/tools/mongodb/delete/dropDatabase.ts +++ b/src/tools/mongodb/delete/dropDatabase.ts @@ -3,12 +3,12 @@ import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js"; import { ToolArgs, OperationType } from "../../tool.js"; export class DropDatabaseTool extends MongoDBToolBase { - protected name = "drop-database"; + public name = "drop-database"; protected description = "Removes the specified database, deleting the associated data files"; protected argsShape = { database: DbOperationArgs.database, }; - protected operationType: OperationType = "delete"; + public operationType: OperationType = "delete"; protected async execute({ database }: ToolArgs): Promise { const provider = await this.ensureConnected(); diff --git a/src/tools/mongodb/metadata/collectionSchema.ts b/src/tools/mongodb/metadata/collectionSchema.ts index f01453232..693b8f916 100644 --- a/src/tools/mongodb/metadata/collectionSchema.ts +++ b/src/tools/mongodb/metadata/collectionSchema.ts @@ -4,11 +4,11 @@ import { ToolArgs, OperationType } from "../../tool.js"; import { getSimplifiedSchema } from "mongodb-schema"; export class CollectionSchemaTool extends MongoDBToolBase { - protected name = "collection-schema"; + public name = "collection-schema"; protected description = "Describe the schema for a collection"; protected argsShape = DbOperationArgs; - protected operationType: OperationType = "metadata"; + public operationType: OperationType = "metadata"; protected async execute({ database, collection }: ToolArgs): Promise { const provider = await this.ensureConnected(); diff --git a/src/tools/mongodb/metadata/collectionStorageSize.ts b/src/tools/mongodb/metadata/collectionStorageSize.ts index 127e7172d..7a37499aa 100644 --- a/src/tools/mongodb/metadata/collectionStorageSize.ts +++ b/src/tools/mongodb/metadata/collectionStorageSize.ts @@ -3,11 +3,11 @@ import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js"; import { ToolArgs, OperationType } from "../../tool.js"; export class CollectionStorageSizeTool extends MongoDBToolBase { - protected name = "collection-storage-size"; + public name = "collection-storage-size"; protected description = "Gets the size of the collection"; protected argsShape = DbOperationArgs; - protected operationType: OperationType = "metadata"; + public operationType: OperationType = "metadata"; protected async execute({ database, collection }: ToolArgs): Promise { const provider = await this.ensureConnected(); diff --git a/src/tools/mongodb/metadata/dbStats.ts b/src/tools/mongodb/metadata/dbStats.ts index a8c0ea0d9..ee819c556 100644 --- a/src/tools/mongodb/metadata/dbStats.ts +++ b/src/tools/mongodb/metadata/dbStats.ts @@ -4,13 +4,13 @@ import { ToolArgs, OperationType } from "../../tool.js"; import { EJSON } from "bson"; export class DbStatsTool extends MongoDBToolBase { - protected name = "db-stats"; + public name = "db-stats"; protected description = "Returns statistics that reflect the use state of a single database"; protected argsShape = { database: DbOperationArgs.database, }; - protected operationType: OperationType = "metadata"; + public operationType: OperationType = "metadata"; protected async execute({ database }: ToolArgs): Promise { const provider = await this.ensureConnected(); diff --git a/src/tools/mongodb/metadata/explain.ts b/src/tools/mongodb/metadata/explain.ts index 1068a0084..a686d9cce 100644 --- a/src/tools/mongodb/metadata/explain.ts +++ b/src/tools/mongodb/metadata/explain.ts @@ -8,7 +8,7 @@ import { FindArgs } from "../read/find.js"; import { CountArgs } from "../read/count.js"; export class ExplainTool extends MongoDBToolBase { - protected name = "explain"; + public name = "explain"; protected description = "Returns statistics describing the execution of the winning plan chosen by the query optimizer for the evaluated method"; @@ -34,7 +34,7 @@ export class ExplainTool extends MongoDBToolBase { .describe("The method and its arguments to run"), }; - protected operationType: OperationType = "metadata"; + public operationType: OperationType = "metadata"; static readonly defaultVerbosity = ExplainVerbosity.queryPlanner; diff --git a/src/tools/mongodb/metadata/listCollections.ts b/src/tools/mongodb/metadata/listCollections.ts index 193d0465c..9611d5419 100644 --- a/src/tools/mongodb/metadata/listCollections.ts +++ b/src/tools/mongodb/metadata/listCollections.ts @@ -3,13 +3,13 @@ import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js"; import { ToolArgs, OperationType } from "../../tool.js"; export class ListCollectionsTool extends MongoDBToolBase { - protected name = "list-collections"; + public name = "list-collections"; protected description = "List all collections for a given database"; protected argsShape = { database: DbOperationArgs.database, }; - protected operationType: OperationType = "metadata"; + public operationType: OperationType = "metadata"; protected async execute({ database }: ToolArgs): Promise { const provider = await this.ensureConnected(); diff --git a/src/tools/mongodb/metadata/listDatabases.ts b/src/tools/mongodb/metadata/listDatabases.ts index fe324f07f..400f275ba 100644 --- a/src/tools/mongodb/metadata/listDatabases.ts +++ b/src/tools/mongodb/metadata/listDatabases.ts @@ -4,10 +4,10 @@ import * as bson from "bson"; import { OperationType } from "../../tool.js"; export class ListDatabasesTool extends MongoDBToolBase { - protected name = "list-databases"; + public name = "list-databases"; protected description = "List all databases for a MongoDB connection"; protected argsShape = {}; - protected operationType: OperationType = "metadata"; + public operationType: OperationType = "metadata"; protected async execute(): Promise { const provider = await this.ensureConnected(); diff --git a/src/tools/mongodb/metadata/logs.ts b/src/tools/mongodb/metadata/logs.ts index 9056aa590..899738fd1 100644 --- a/src/tools/mongodb/metadata/logs.ts +++ b/src/tools/mongodb/metadata/logs.ts @@ -4,7 +4,7 @@ import { ToolArgs, OperationType } from "../../tool.js"; import { z } from "zod"; export class LogsTool extends MongoDBToolBase { - protected name = "mongodb-logs"; + public name = "mongodb-logs"; protected description = "Returns the most recent logged mongod events"; protected argsShape = { type: z @@ -24,7 +24,7 @@ export class LogsTool extends MongoDBToolBase { .describe("The maximum number of log entries to return."), }; - protected operationType: OperationType = "metadata"; + public operationType: OperationType = "metadata"; protected async execute({ type, limit }: ToolArgs): Promise { const provider = await this.ensureConnected(); diff --git a/src/tools/mongodb/mongodbTool.ts b/src/tools/mongodb/mongodbTool.ts index fe996a381..2e5c68c7f 100644 --- a/src/tools/mongodb/mongodbTool.ts +++ b/src/tools/mongodb/mongodbTool.ts @@ -4,6 +4,7 @@ import { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver import { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; import { ErrorCodes, MongoDBError } from "../../errors.js"; import logger, { LogId } from "../../logger.js"; +import { Server } from "../../server.js"; export const DbOperationArgs = { database: z.string().describe("Database name"), @@ -11,7 +12,8 @@ export const DbOperationArgs = { }; export abstract class MongoDBToolBase extends ToolBase { - protected category: ToolCategory = "mongodb"; + private server?: Server; + public category: ToolCategory = "mongodb"; protected async ensureConnected(): Promise { if (!this.session.serviceProvider) { @@ -43,11 +45,28 @@ export abstract class MongoDBToolBase extends ToolBase { return this.session.serviceProvider; } + public register(server: Server): boolean { + this.server = server; + return super.register(server); + } + protected handleError( error: unknown, args: ToolArgs ): Promise | CallToolResult { if (error instanceof MongoDBError) { + const connectTools = this.server?.tools + .filter((t) => t.operationType === "connect") + .sort((a, b) => a.category.localeCompare(b.category)); // Sort Altas tools before MongoDB tools + + // Find the first Atlas connect tool if available and suggest to the LLM to use it. + // Note: if we ever have multiple Atlas connect tools, we may want to refine this logic to select the most appropriate one. + const atlasConnectTool = connectTools?.find((t) => t.category === "atlas"); + const llmConnectHint = atlasConnectTool + ? `Note to LLM: prefer using the "${atlasConnectTool.name}" tool to connect to an Atlas cluster over using a connection string. Make sure to ask the user to specify a cluster name they want to connect to or ask them if they want to use the "list-clusters" tool to list all their clusters. Do not invent cluster names or connection strings unless the user has explicitly specified them. If they've previously connected to MongoDB using MCP, you can ask them if they want to reconnect using the same cluster/connection.` + : "Note to LLM: do not invent connection strings and explicitly ask the user to provide one. If they have previously connected to MongoDB using MCP, you can ask them if they want to reconnect using the same connection string."; + + const connectToolsNames = connectTools?.map((t) => `"${t.name}"`).join(", "); switch (error.code) { case ErrorCodes.NotConnectedToMongoDB: return { @@ -58,7 +77,9 @@ export abstract class MongoDBToolBase extends ToolBase { }, { type: "text", - text: "Please use the 'connect' or 'switch-connection' tool to connect to a MongoDB instance.", + text: connectToolsNames + ? `Please use one of the following tools: ${connectToolsNames} to connect to a MongoDB instance or update the MCP server configuration to include a connection string. ${llmConnectHint}` + : "There are no tools available to connect. Please update the configuration to include a connection string and restart the server.", }, ], isError: true, @@ -68,7 +89,13 @@ export abstract class MongoDBToolBase extends ToolBase { content: [ { type: "text", - text: "The configured connection string is not valid. Please check the connection string and confirm it points to a valid MongoDB instance. Alternatively, use the 'switch-connection' tool to connect to a different instance.", + text: "The configured connection string is not valid. Please check the connection string and confirm it points to a valid MongoDB instance.", + }, + { + type: "text", + text: connectTools + ? `Alternatively, you can use one of the following tools: ${connectToolsNames} to connect to a MongoDB instance. ${llmConnectHint}` + : "Please update the configuration to use a valid connection string and restart the server.", }, ], isError: true, diff --git a/src/tools/mongodb/read/aggregate.ts b/src/tools/mongodb/read/aggregate.ts index aa21fc5d5..f9868dba8 100644 --- a/src/tools/mongodb/read/aggregate.ts +++ b/src/tools/mongodb/read/aggregate.ts @@ -10,13 +10,13 @@ export const AggregateArgs = { }; export class AggregateTool extends MongoDBToolBase { - protected name = "aggregate"; + public name = "aggregate"; protected description = "Run an aggregation against a MongoDB collection"; protected argsShape = { ...DbOperationArgs, ...AggregateArgs, }; - protected operationType: OperationType = "read"; + public operationType: OperationType = "read"; protected async execute({ database, diff --git a/src/tools/mongodb/read/collectionIndexes.ts b/src/tools/mongodb/read/collectionIndexes.ts index cc0a141bc..ef3fa75df 100644 --- a/src/tools/mongodb/read/collectionIndexes.ts +++ b/src/tools/mongodb/read/collectionIndexes.ts @@ -3,10 +3,10 @@ import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js"; import { ToolArgs, OperationType } from "../../tool.js"; export class CollectionIndexesTool extends MongoDBToolBase { - protected name = "collection-indexes"; + public name = "collection-indexes"; protected description = "Describe the indexes for a collection"; protected argsShape = DbOperationArgs; - protected operationType: OperationType = "read"; + public operationType: OperationType = "read"; protected async execute({ database, collection }: ToolArgs): Promise { const provider = await this.ensureConnected(); diff --git a/src/tools/mongodb/read/count.ts b/src/tools/mongodb/read/count.ts index 0ed3a1924..df3664b57 100644 --- a/src/tools/mongodb/read/count.ts +++ b/src/tools/mongodb/read/count.ts @@ -14,7 +14,7 @@ export const CountArgs = { }; export class CountTool extends MongoDBToolBase { - protected name = "count"; + public name = "count"; protected description = "Gets the number of documents in a MongoDB collection using db.collection.count() and query as an optional filter parameter"; protected argsShape = { @@ -22,7 +22,7 @@ export class CountTool extends MongoDBToolBase { ...CountArgs, }; - protected operationType: OperationType = "read"; + public operationType: OperationType = "read"; protected async execute({ database, collection, query }: ToolArgs): Promise { const provider = await this.ensureConnected(); diff --git a/src/tools/mongodb/read/find.ts b/src/tools/mongodb/read/find.ts index 97c90e08e..02c337edb 100644 --- a/src/tools/mongodb/read/find.ts +++ b/src/tools/mongodb/read/find.ts @@ -23,13 +23,13 @@ export const FindArgs = { }; export class FindTool extends MongoDBToolBase { - protected name = "find"; + public name = "find"; protected description = "Run a find query against a MongoDB collection"; protected argsShape = { ...DbOperationArgs, ...FindArgs, }; - protected operationType: OperationType = "read"; + public operationType: OperationType = "read"; protected async execute({ database, diff --git a/src/tools/mongodb/tools.ts b/src/tools/mongodb/tools.ts index d64d53ea7..c74fdf294 100644 --- a/src/tools/mongodb/tools.ts +++ b/src/tools/mongodb/tools.ts @@ -1,4 +1,4 @@ -import { ConnectTool } from "./metadata/connect.js"; +import { ConnectTool } from "./connect/connect.js"; import { ListCollectionsTool } from "./metadata/listCollections.js"; import { CollectionIndexesTool } from "./read/collectionIndexes.js"; import { ListDatabasesTool } from "./metadata/listDatabases.js"; diff --git a/src/tools/mongodb/update/renameCollection.ts b/src/tools/mongodb/update/renameCollection.ts index d3b07c157..e5bffbdb4 100644 --- a/src/tools/mongodb/update/renameCollection.ts +++ b/src/tools/mongodb/update/renameCollection.ts @@ -4,14 +4,14 @@ import { DbOperationArgs, MongoDBToolBase } from "../mongodbTool.js"; import { ToolArgs, OperationType } from "../../tool.js"; export class RenameCollectionTool extends MongoDBToolBase { - protected name = "rename-collection"; + public name = "rename-collection"; protected description = "Renames a collection in a MongoDB database"; protected argsShape = { ...DbOperationArgs, newName: z.string().describe("The new name for the collection"), dropTarget: z.boolean().optional().default(false).describe("If true, drops the target collection if it exists"), }; - protected operationType: OperationType = "update"; + public operationType: OperationType = "update"; protected async execute({ database, diff --git a/src/tools/mongodb/update/updateMany.ts b/src/tools/mongodb/update/updateMany.ts index 7392135b6..b31a843e6 100644 --- a/src/tools/mongodb/update/updateMany.ts +++ b/src/tools/mongodb/update/updateMany.ts @@ -5,7 +5,7 @@ import { ToolArgs, OperationType } from "../../tool.js"; import { checkIndexUsage } from "../../../helpers/indexCheck.js"; export class UpdateManyTool extends MongoDBToolBase { - protected name = "update-many"; + public name = "update-many"; protected description = "Updates all documents that match the specified filter for a collection"; protected argsShape = { ...DbOperationArgs, @@ -23,7 +23,7 @@ export class UpdateManyTool extends MongoDBToolBase { .optional() .describe("Controls whether to insert a new document if no documents match the filter"), }; - protected operationType: OperationType = "update"; + public operationType: OperationType = "update"; protected async execute({ database, diff --git a/src/tools/tool.ts b/src/tools/tool.ts index b7cce3547..551374d66 100644 --- a/src/tools/tool.ts +++ b/src/tools/tool.ts @@ -1,15 +1,16 @@ import { z, type ZodRawShape, type ZodNever, AnyZodObject } from "zod"; -import type { McpServer, RegisteredTool, ToolCallback } from "@modelcontextprotocol/sdk/server/mcp.js"; +import type { RegisteredTool, ToolCallback } from "@modelcontextprotocol/sdk/server/mcp.js"; import type { CallToolResult, ToolAnnotations } from "@modelcontextprotocol/sdk/types.js"; import { Session } from "../session.js"; import logger, { LogId } from "../logger.js"; import { Telemetry } from "../telemetry/telemetry.js"; import { type ToolEvent } from "../telemetry/types.js"; import { UserConfig } from "../config.js"; +import { Server } from "../server.js"; export type ToolArgs = z.objectOutputType; -export type OperationType = "metadata" | "read" | "create" | "delete" | "update"; +export type OperationType = "metadata" | "read" | "create" | "delete" | "update" | "connect"; export type ToolCategory = "mongodb" | "atlas"; export type TelemetryToolMetadata = { projectId?: string; @@ -17,11 +18,11 @@ export type TelemetryToolMetadata = { }; export abstract class ToolBase { - protected abstract name: string; + public abstract name: string; - protected abstract category: ToolCategory; + public abstract category: ToolCategory; - protected abstract operationType: OperationType; + public abstract operationType: OperationType; protected abstract description: string; @@ -36,6 +37,7 @@ export abstract class ToolBase { switch (this.operationType) { case "read": case "metadata": + case "connect": annotations.readOnlyHint = true; annotations.destructiveHint = false; break; @@ -63,9 +65,9 @@ export abstract class ToolBase { protected readonly telemetry: Telemetry ) {} - public register(server: McpServer): void { + public register(server: Server): boolean { if (!this.verifyAllowed()) { - return; + return false; } const callback: ToolCallback = async (...args) => { @@ -84,14 +86,15 @@ export abstract class ToolBase { } }; - server.tool(this.name, this.description, this.argsShape, this.annotations, callback); + server.mcpServer.tool(this.name, this.description, this.argsShape, this.annotations, callback); // This is very similar to RegisteredTool.update, but without the bugs around the name. // In the upstream update method, the name is captured in the closure and not updated when // the tool name changes. This means that you only get one name update before things end up // in a broken state. + // See https://github.com/modelcontextprotocol/typescript-sdk/issues/414 for more details. this.update = (updates: { name?: string; description?: string; inputSchema?: AnyZodObject }) => { - const tools = server["_registeredTools"] as { [toolName: string]: RegisteredTool }; + const tools = server.mcpServer["_registeredTools"] as { [toolName: string]: RegisteredTool }; const existingTool = tools[this.name]; if (!existingTool) { @@ -118,8 +121,10 @@ export abstract class ToolBase { existingTool.inputSchema = updates.inputSchema; } - server.sendToolListChanged(); + server.mcpServer.sendToolListChanged(); }; + + return true; } protected update?: (updates: { name?: string; description?: string; inputSchema?: AnyZodObject }) => void; diff --git a/tests/integration/tools/atlas/clusters.test.ts b/tests/integration/tools/atlas/clusters.test.ts index 62bd422c4..b5f34bdfb 100644 --- a/tests/integration/tools/atlas/clusters.test.ts +++ b/tests/integration/tools/atlas/clusters.test.ts @@ -1,5 +1,5 @@ import { Session } from "../../../../src/session.js"; -import { expectDefined } from "../../helpers.js"; +import { expectDefined, getResponseElements } from "../../helpers.js"; import { describeWithAtlas, withProject, randomId } from "./atlasHelpers.js"; import { ClusterDescription20240805 } from "../../../../src/common/atlas/openapi.js"; import { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; @@ -205,6 +205,23 @@ describeWithAtlas("clusters", (integration) => { await sleep(500); } }); + + describe("when not connected", () => { + it("prompts for atlas-connect-cluster when querying mongodb", async () => { + const response = await integration.mcpClient().callTool({ + name: "find", + arguments: { database: "some-db", collection: "some-collection" }, + }); + const elements = getResponseElements(response.content); + expect(elements).toHaveLength(2); + expect(elements[0]?.text).toContain( + "You need to connect to a MongoDB instance before you can access its data." + ); + expect(elements[1]?.text).toContain( + 'Please use one of the following tools: "atlas-connect-cluster", "connect" to connect to a MongoDB instance' + ); + }); + }); }); }); }); diff --git a/tests/integration/tools/mongodb/metadata/connect.test.ts b/tests/integration/tools/mongodb/connect/connect.test.ts similarity index 82% rename from tests/integration/tools/mongodb/metadata/connect.test.ts rename to tests/integration/tools/mongodb/connect/connect.test.ts index 47e91d131..857b57475 100644 --- a/tests/integration/tools/mongodb/metadata/connect.test.ts +++ b/tests/integration/tools/mongodb/connect/connect.test.ts @@ -1,9 +1,15 @@ import { describeWithMongoDB } from "../mongodbHelpers.js"; -import { getResponseContent, validateThrowsForInvalidArguments, validateToolMetadata } from "../../../helpers.js"; +import { + getResponseContent, + getResponseElements, + validateThrowsForInvalidArguments, + validateToolMetadata, +} from "../../../helpers.js"; import { config } from "../../../../../src/config.js"; +import { defaultTestConfig, setupIntegrationTest } from "../../../helpers.js"; describeWithMongoDB( - "switchConnection tool", + "SwitchConnection tool", (integration) => { beforeEach(() => { integration.mcpServer().userConfig.connectionString = integration.connectionString(); @@ -77,6 +83,7 @@ describeWithMongoDB( connectionString: mdbIntegration.connectionString(), }) ); + describeWithMongoDB( "Connect tool", (integration) => { @@ -126,3 +133,26 @@ describeWithMongoDB( }, () => config ); + +describe("Connect tool when disabled", () => { + const integration = setupIntegrationTest(() => ({ + ...defaultTestConfig, + disabledTools: ["connect"], + })); + + it("is not suggested when querying MongoDB disconnected", async () => { + const response = await integration.mcpClient().callTool({ + name: "find", + arguments: { database: "some-db", collection: "some-collection" }, + }); + + const elements = getResponseElements(response); + expect(elements).toHaveLength(2); + expect(elements[0]?.text).toContain( + "You need to connect to a MongoDB instance before you can access its data." + ); + expect(elements[1]?.text).toContain( + "There are no tools available to connect. Please update the configuration to include a connection string and restart the server." + ); + }); +});