Skip to content

Commit 1ae7f5e

Browse files
chore: allows ConnectionManager to be injected via TransportRunner MCP-131 (#481)
1 parent bfae3f7 commit 1ae7f5e

File tree

16 files changed

+203
-119
lines changed

16 files changed

+203
-119
lines changed

eslint-rules/no-config-imports.js

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ const allowedConfigValueImportFiles = [
1010
"src/index.ts",
1111
// Config resource definition that works with the some config values
1212
"src/resources/common/config.ts",
13+
// The file exports, a factory function to create MCPConnectionManager and
14+
// it relies on driver options generator and default driver options from
15+
// config file.
16+
"src/common/connectionManager.ts",
1317
];
1418

1519
// Ref: https://eslint.org/docs/latest/extend/custom-rules

src/common/connectionManager.ts

Lines changed: 103 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,14 @@
1-
import type { UserConfig, DriverOptions } from "./config.js";
2-
import { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver";
3-
import EventEmitter from "events";
4-
import { setAppNameParamIfMissing } from "../helpers/connectionOptions.js";
5-
import { packageInfo } from "./packageInfo.js";
6-
import ConnectionString from "mongodb-connection-string-url";
1+
import { EventEmitter } from "events";
72
import type { MongoClientOptions } from "mongodb";
8-
import { ErrorCodes, MongoDBError } from "./errors.js";
3+
import ConnectionString from "mongodb-connection-string-url";
4+
import { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver";
5+
import { type ConnectionInfo, generateConnectionInfoFromCliArgs } from "@mongosh/arg-parser";
96
import type { DeviceId } from "../helpers/deviceId.js";
10-
import type { AppNameComponents } from "../helpers/connectionOptions.js";
11-
import type { CompositeLogger } from "./logger.js";
12-
import { LogId } from "./logger.js";
13-
import type { ConnectionInfo } from "@mongosh/arg-parser";
14-
import { generateConnectionInfoFromCliArgs } from "@mongosh/arg-parser";
7+
import { defaultDriverOptions, setupDriverConfig, type DriverOptions, type UserConfig } from "./config.js";
8+
import { MongoDBError, ErrorCodes } from "./errors.js";
9+
import { type LoggerBase, LogId } from "./logger.js";
10+
import { packageInfo } from "./packageInfo.js";
11+
import { type AppNameComponents, setAppNameParamIfMissing } from "../helpers/connectionOptions.js";
1512

1613
export interface AtlasClusterConnectionInfo {
1714
username: string;
@@ -71,39 +68,76 @@ export interface ConnectionManagerEvents {
7168
"connection-error": [ConnectionStateErrored];
7269
}
7370

74-
export class ConnectionManager extends EventEmitter<ConnectionManagerEvents> {
71+
/**
72+
* For a few tests, we need the changeState method to force a connection state
73+
* which is we have this type to typecast the actual ConnectionManager with
74+
* public changeState (only to make TS happy).
75+
*/
76+
export type TestConnectionManager = ConnectionManager & {
77+
changeState<Event extends keyof ConnectionManagerEvents, State extends ConnectionManagerEvents[Event][0]>(
78+
event: Event,
79+
newState: State
80+
): State;
81+
};
82+
83+
export abstract class ConnectionManager {
84+
protected clientName: string;
85+
protected readonly _events;
86+
readonly events: Pick<EventEmitter<ConnectionManagerEvents>, "on" | "off" | "once">;
7587
private state: AnyConnectionState;
88+
89+
constructor() {
90+
this.clientName = "unknown";
91+
this.events = this._events = new EventEmitter<ConnectionManagerEvents>();
92+
this.state = { tag: "disconnected" };
93+
}
94+
95+
get currentConnectionState(): AnyConnectionState {
96+
return this.state;
97+
}
98+
99+
protected changeState<Event extends keyof ConnectionManagerEvents, State extends ConnectionManagerEvents[Event][0]>(
100+
event: Event,
101+
newState: State
102+
): State {
103+
this.state = newState;
104+
// TypeScript doesn't seem to be happy with the spread operator and generics
105+
// eslint-disable-next-line
106+
this._events.emit(event, ...([newState] as any));
107+
return newState;
108+
}
109+
110+
setClientName(clientName: string): void {
111+
this.clientName = clientName;
112+
}
113+
114+
abstract connect(settings: ConnectionSettings): Promise<AnyConnectionState>;
115+
116+
abstract disconnect(): Promise<ConnectionStateDisconnected | ConnectionStateErrored>;
117+
}
118+
119+
export class MCPConnectionManager extends ConnectionManager {
76120
private deviceId: DeviceId;
77-
private clientName: string;
78121
private bus: EventEmitter;
79122

80123
constructor(
81124
private userConfig: UserConfig,
82125
private driverOptions: DriverOptions,
83-
private logger: CompositeLogger,
126+
private logger: LoggerBase,
84127
deviceId: DeviceId,
85128
bus?: EventEmitter
86129
) {
87130
super();
88-
89131
this.bus = bus ?? new EventEmitter();
90-
this.state = { tag: "disconnected" };
91-
92132
this.bus.on("mongodb-oidc-plugin:auth-failed", this.onOidcAuthFailed.bind(this));
93133
this.bus.on("mongodb-oidc-plugin:auth-succeeded", this.onOidcAuthSucceeded.bind(this));
94-
95134
this.deviceId = deviceId;
96-
this.clientName = "unknown";
97-
}
98-
99-
setClientName(clientName: string): void {
100-
this.clientName = clientName;
101135
}
102136

103137
async connect(settings: ConnectionSettings): Promise<AnyConnectionState> {
104-
this.emit("connection-request", this.state);
138+
this._events.emit("connection-request", this.currentConnectionState);
105139

106-
if (this.state.tag === "connected" || this.state.tag === "connecting") {
140+
if (this.currentConnectionState.tag === "connected" || this.currentConnectionState.tag === "connecting") {
107141
await this.disconnect();
108142
}
109143

@@ -138,7 +172,7 @@ export class ConnectionManager extends EventEmitter<ConnectionManagerEvents> {
138172
connectionInfo.driverOptions.proxy ??= { useEnvironmentVariableProxies: true };
139173
connectionInfo.driverOptions.applyProxyToOIDC ??= true;
140174

141-
connectionStringAuthType = ConnectionManager.inferConnectionTypeFromSettings(
175+
connectionStringAuthType = MCPConnectionManager.inferConnectionTypeFromSettings(
142176
this.userConfig,
143177
connectionInfo
144178
);
@@ -165,7 +199,10 @@ export class ConnectionManager extends EventEmitter<ConnectionManagerEvents> {
165199
}
166200

167201
try {
168-
const connectionType = ConnectionManager.inferConnectionTypeFromSettings(this.userConfig, connectionInfo);
202+
const connectionType = MCPConnectionManager.inferConnectionTypeFromSettings(
203+
this.userConfig,
204+
connectionInfo
205+
);
169206
if (connectionType.startsWith("oidc")) {
170207
void this.pingAndForget(serviceProvider);
171208

@@ -199,13 +236,13 @@ export class ConnectionManager extends EventEmitter<ConnectionManagerEvents> {
199236
}
200237

201238
async disconnect(): Promise<ConnectionStateDisconnected | ConnectionStateErrored> {
202-
if (this.state.tag === "disconnected" || this.state.tag === "errored") {
203-
return this.state;
239+
if (this.currentConnectionState.tag === "disconnected" || this.currentConnectionState.tag === "errored") {
240+
return this.currentConnectionState;
204241
}
205242

206-
if (this.state.tag === "connected" || this.state.tag === "connecting") {
243+
if (this.currentConnectionState.tag === "connected" || this.currentConnectionState.tag === "connecting") {
207244
try {
208-
await this.state.serviceProvider?.close(true);
245+
await this.currentConnectionState.serviceProvider?.close(true);
209246
} finally {
210247
this.changeState("connection-close", {
211248
tag: "disconnected",
@@ -216,30 +253,21 @@ export class ConnectionManager extends EventEmitter<ConnectionManagerEvents> {
216253
return { tag: "disconnected" };
217254
}
218255

219-
get currentConnectionState(): AnyConnectionState {
220-
return this.state;
221-
}
222-
223-
changeState<Event extends keyof ConnectionManagerEvents, State extends ConnectionManagerEvents[Event][0]>(
224-
event: Event,
225-
newState: State
226-
): State {
227-
this.state = newState;
228-
// TypeScript doesn't seem to be happy with the spread operator and generics
229-
// eslint-disable-next-line
230-
this.emit(event, ...([newState] as any));
231-
return newState;
232-
}
233-
234256
private onOidcAuthFailed(error: unknown): void {
235-
if (this.state.tag === "connecting" && this.state.connectionStringAuthType?.startsWith("oidc")) {
257+
if (
258+
this.currentConnectionState.tag === "connecting" &&
259+
this.currentConnectionState.connectionStringAuthType?.startsWith("oidc")
260+
) {
236261
void this.disconnectOnOidcError(error);
237262
}
238263
}
239264

240265
private onOidcAuthSucceeded(): void {
241-
if (this.state.tag === "connecting" && this.state.connectionStringAuthType?.startsWith("oidc")) {
242-
this.changeState("connection-success", { ...this.state, tag: "connected" });
266+
if (
267+
this.currentConnectionState.tag === "connecting" &&
268+
this.currentConnectionState.connectionStringAuthType?.startsWith("oidc")
269+
) {
270+
this.changeState("connection-success", { ...this.currentConnectionState, tag: "connected" });
243271
}
244272

245273
this.logger.info({
@@ -250,9 +278,12 @@ export class ConnectionManager extends EventEmitter<ConnectionManagerEvents> {
250278
}
251279

252280
private onOidcNotifyDeviceFlow(flowInfo: { verificationUrl: string; userCode: string }): void {
253-
if (this.state.tag === "connecting" && this.state.connectionStringAuthType?.startsWith("oidc")) {
281+
if (
282+
this.currentConnectionState.tag === "connecting" &&
283+
this.currentConnectionState.connectionStringAuthType?.startsWith("oidc")
284+
) {
254285
this.changeState("connection-request", {
255-
...this.state,
286+
...this.currentConnectionState,
256287
tag: "connecting",
257288
connectionStringAuthType: "oidc-device-flow",
258289
oidcLoginUrl: flowInfo.verificationUrl,
@@ -329,3 +360,23 @@ export class ConnectionManager extends EventEmitter<ConnectionManagerEvents> {
329360
}
330361
}
331362
}
363+
364+
/**
365+
* Consumers of MCP server library have option to bring their own connection
366+
* management if they need to. To support that, we enable injecting connection
367+
* manager implementation through a factory function.
368+
*/
369+
export type ConnectionManagerFactoryFn = (createParams: {
370+
logger: LoggerBase;
371+
deviceId: DeviceId;
372+
userConfig: UserConfig;
373+
}) => Promise<ConnectionManager>;
374+
375+
export const createMCPConnectionManager: ConnectionManagerFactoryFn = ({ logger, deviceId, userConfig }) => {
376+
const driverOptions = setupDriverConfig({
377+
config: userConfig,
378+
defaults: defaultDriverOptions,
379+
});
380+
381+
return Promise.resolve(new MCPConnectionManager(userConfig, driverOptions, logger, deviceId));
382+
};

src/common/session.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,10 @@ export class Session extends EventEmitter<SessionEvents> {
6767
this.apiClient = new ApiClient({ baseUrl: apiBaseUrl, credentials }, logger);
6868
this.exportsManager = exportsManager;
6969
this.connectionManager = connectionManager;
70-
this.connectionManager.on("connection-success", () => this.emit("connect"));
71-
this.connectionManager.on("connection-time-out", (error) => this.emit("connection-error", error));
72-
this.connectionManager.on("connection-close", () => this.emit("disconnect"));
73-
this.connectionManager.on("connection-error", (error) => this.emit("connection-error", error));
70+
this.connectionManager.events.on("connection-success", () => this.emit("connect"));
71+
this.connectionManager.events.on("connection-time-out", (error) => this.emit("connection-error", error));
72+
this.connectionManager.events.on("connection-close", () => this.emit("disconnect"));
73+
this.connectionManager.events.on("connection-error", (error) => this.emit("connection-error", error));
7474
}
7575

7676
setMcpClient(mcpClient: Implementation | undefined): void {

src/index.ts

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ function enableFipsIfRequested(): void {
3636
enableFipsIfRequested();
3737

3838
import { ConsoleLogger, LogId } from "./common/logger.js";
39-
import { config, driverOptions } from "./common/config.js";
39+
import { config } from "./common/config.js";
4040
import crypto from "crypto";
4141
import { packageInfo } from "./common/packageInfo.js";
4242
import { StdioRunner } from "./transports/stdio.js";
@@ -49,10 +49,7 @@ async function main(): Promise<void> {
4949
assertHelpMode();
5050
assertVersionMode();
5151

52-
const transportRunner =
53-
config.transport === "stdio"
54-
? new StdioRunner(config, driverOptions)
55-
: new StreamableHttpRunner(config, driverOptions);
52+
const transportRunner = config.transport === "stdio" ? new StdioRunner(config) : new StreamableHttpRunner(config);
5653
const shutdown = (): void => {
5754
transportRunner.logger.info({
5855
id: LogId.serverCloseRequested,

src/lib.ts

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
11
export { Server, type ServerOptions } from "./server.js";
2-
export { Telemetry } from "./telemetry/telemetry.js";
32
export { Session, type SessionOptions } from "./common/session.js";
4-
export { type UserConfig, defaultUserConfig } from "./common/config.js";
3+
export { defaultUserConfig, type UserConfig } from "./common/config.js";
4+
export { LoggerBase, type LogPayload, type LoggerType, type LogLevel } from "./common/logger.js";
55
export { StreamableHttpRunner } from "./transports/streamableHttp.js";
6-
export { LoggerBase } from "./common/logger.js";
7-
export type { LogPayload, LoggerType, LogLevel } from "./common/logger.js";
6+
export {
7+
ConnectionManager,
8+
type AnyConnectionState,
9+
type ConnectionState,
10+
type ConnectionStateDisconnected,
11+
type ConnectionStateErrored,
12+
type ConnectionManagerFactoryFn,
13+
} from "./common/connectionManager.js";
14+
export { Telemetry } from "./telemetry/telemetry.js";

src/transports/base.ts

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import type { DriverOptions, UserConfig } from "../common/config.js";
1+
import type { UserConfig } from "../common/config.js";
22
import { packageInfo } from "../common/packageInfo.js";
33
import { Server } from "../server.js";
44
import { Session } from "../common/session.js";
@@ -7,16 +7,16 @@ import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js";
77
import type { LoggerBase } from "../common/logger.js";
88
import { CompositeLogger, ConsoleLogger, DiskLogger, McpLogger } from "../common/logger.js";
99
import { ExportsManager } from "../common/exportsManager.js";
10-
import { ConnectionManager } from "../common/connectionManager.js";
1110
import { DeviceId } from "../helpers/deviceId.js";
11+
import { type ConnectionManagerFactoryFn } from "../common/connectionManager.js";
1212

1313
export abstract class TransportRunnerBase {
1414
public logger: LoggerBase;
1515
public deviceId: DeviceId;
1616

1717
protected constructor(
1818
protected readonly userConfig: UserConfig,
19-
private readonly driverOptions: DriverOptions,
19+
private readonly createConnectionManager: ConnectionManagerFactoryFn,
2020
additionalLoggers: LoggerBase[]
2121
) {
2222
const loggers: LoggerBase[] = [...additionalLoggers];
@@ -38,15 +38,19 @@ export abstract class TransportRunnerBase {
3838
this.deviceId = DeviceId.create(this.logger);
3939
}
4040

41-
protected setupServer(): Server {
41+
protected async setupServer(): Promise<Server> {
4242
const mcpServer = new McpServer({
4343
name: packageInfo.mcpServerName,
4444
version: packageInfo.version,
4545
});
4646

4747
const logger = new CompositeLogger(this.logger);
4848
const exportsManager = ExportsManager.init(this.userConfig, logger);
49-
const connectionManager = new ConnectionManager(this.userConfig, this.driverOptions, logger, this.deviceId);
49+
const connectionManager = await this.createConnectionManager({
50+
logger,
51+
userConfig: this.userConfig,
52+
deviceId: this.deviceId,
53+
});
5054

5155
const session = new Session({
5256
apiBaseUrl: this.userConfig.apiBaseUrl,

src/transports/stdio.ts

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
import type { LoggerBase } from "../common/logger.js";
2-
import { LogId } from "../common/logger.js";
3-
import type { Server } from "../server.js";
4-
import { TransportRunnerBase } from "./base.js";
1+
import { EJSON } from "bson";
52
import type { JSONRPCMessage } from "@modelcontextprotocol/sdk/types.js";
63
import { JSONRPCMessageSchema } from "@modelcontextprotocol/sdk/types.js";
7-
import { EJSON } from "bson";
84
import { StdioServerTransport } from "@modelcontextprotocol/sdk/server/stdio.js";
9-
import type { DriverOptions, UserConfig } from "../common/config.js";
5+
import { type LoggerBase, LogId } from "../common/logger.js";
6+
import type { Server } from "../server.js";
7+
import { TransportRunnerBase } from "./base.js";
8+
import { type UserConfig } from "../common/config.js";
9+
import { createMCPConnectionManager, type ConnectionManagerFactoryFn } from "../common/connectionManager.js";
1010

1111
// This is almost a copy of ReadBuffer from @modelcontextprotocol/sdk
1212
// but it uses EJSON.parse instead of JSON.parse to handle BSON types
@@ -55,13 +55,17 @@ export function createStdioTransport(): StdioServerTransport {
5555
export class StdioRunner extends TransportRunnerBase {
5656
private server: Server | undefined;
5757

58-
constructor(userConfig: UserConfig, driverOptions: DriverOptions, additionalLoggers: LoggerBase[] = []) {
59-
super(userConfig, driverOptions, additionalLoggers);
58+
constructor(
59+
userConfig: UserConfig,
60+
createConnectionManager: ConnectionManagerFactoryFn = createMCPConnectionManager,
61+
additionalLoggers: LoggerBase[] = []
62+
) {
63+
super(userConfig, createConnectionManager, additionalLoggers);
6064
}
6165

6266
async start(): Promise<void> {
6367
try {
64-
this.server = this.setupServer();
68+
this.server = await this.setupServer();
6569

6670
const transport = createStdioTransport();
6771

0 commit comments

Comments
 (0)