Skip to content

Add custom tool support in pure C# #293

@Seng-Jik

Description

@Seng-Jik

This class automatically generates the Python code needed to expose an MCP tool implemented in C#. It is shared here for reference and is based on #261.

Users don't need to write any Python code; they just create a subclass of AITool<TArg, TResult>.

#nullable enable

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Text;

using CosmosPrelude.Misc;

using MCPForUnity.Editor;
using MCPForUnity.Editor.Tools;

using Newtonsoft.Json.Linq;

using UnityEditor;

namespace CosmosEditor.Misc
{
    /// <summary>
    /// Strongly-typed <see cref="AITool"/>: the incoming JSON arguments are converted to
    /// <typeparamref name="TArg"/> before the typed <c>Execute(TArg)</c> overload is called.
    /// </summary>
    /// <typeparam name="TArg">Type the JSON argument object is deserialized into.</typeparam>
    /// <typeparam name="TResult">Type of the value the tool returns.</typeparam>
    public abstract class AITool<TArg, TResult> : AITool
        where TArg : notnull
        where TResult : notnull
    {
        /// <summary>
        /// Converts <paramref name="arg"/> with <c>ToObject&lt;TArg&gt;()</c> and dispatches to the
        /// typed overload. The conversion result goes through <c>ThrowIfNull</c> first —
        /// presumably throwing with the given message when deserialization yields null.
        /// </summary>
        protected sealed override object Execute(JObject arg)
        {
            var typedArg = arg
                .ToObject<TArg>()
                .ThrowIfNull("Failed to deserialize arguments.");

            return Execute(typedArg);
        }

        /// <summary>Runs the tool with already-deserialized arguments.</summary>
        protected abstract TResult Execute(TArg arg);
    }

    /// <summary>
    /// Base class for MCP tools implemented purely in C#.
    /// </summary>
    /// <remarks>
    /// When the editor loads, <c>ToolInstaller</c> instantiates every concrete, non-generic
    /// subclass found in the assemblies listed in <c>assembliesIncludeAITools</c>. It registers the
    /// instance's <c>Execute(JObject)</c> with <c>CommandRegistry</c> under <c>name</c>, and
    /// generates a matching Python <c>@mcp.tool()</c> bridge function into the MCP server's
    /// <c>tools</c> directory. Subclasses are created with
    /// <see cref="Activator.CreateInstance(Type)"/>, so they need a public parameterless constructor.
    /// </remarks>
    public abstract class AITool
    {
        /// <summary>Describes one argument of the generated Python tool function.</summary>
        /// <param name="argName">
        /// Python parameter name; also the key under which the value is forwarded to Unity.
        /// It is emitted verbatim into Python code, so it must be a valid Python identifier.
        /// The generated parameter has no default value.
        /// </param>
        /// <param name="argType">C# type of the argument; only its name appears, in the generated docstring.</param>
        /// <param name="description">Text placed in the docstring's <c>Args:</c> section.</param>
        public record ArgDesc(
            string argName,
            Type argType,
            string description);

        /// <summary>Convenience form of <see cref="ArgDesc"/> taking the argument type as <typeparamref name="T"/>.</summary>
        public record ArgDesc<T>(
            string argName,
            string description)
            : ArgDesc(argName, typeof(T), description);

        /// <summary>
        /// Tool name. Used as the <c>CommandRegistry</c> key, as the generated Python function
        /// name, and as the command string the Python bridge sends to Unity. It must therefore
        /// be a valid Python identifier.
        /// </summary>
        protected abstract string name { get; }

        /// <summary>Tool description; becomes the opening text of the generated Python docstring.</summary>
        protected abstract string desc { get; }

        /// <summary>
        /// Arguments of the generated Python function, in order. Each value is collected into a
        /// dictionary forwarded to Unity; entries whose Python value is <c>None</c> are dropped.
        /// </summary>
        protected abstract IReadOnlyList<ArgDesc> argDesc { get; }

        /// <summary>Optional result description; when non-null a <c>Returns:</c> docstring section is generated.</summary>
        protected virtual string? resultDesc => null;

        /// <summary>
        /// Executes the tool with the JSON object delivered by <c>CommandRegistry</c>. This is
        /// presumably the parameter dictionary sent by the generated Python bridge (TODO confirm).
        /// The return value is reported as <c>{ success = true, message = result }</c>; a
        /// <c>UnitType</c> result is reported as <c>"success"</c>. Exceptions are caught and
        /// reported with <c>success = false</c>.
        /// </summary>
        protected abstract object Execute(JObject arg);

        // Assemblies scanned for AITool subclasses; currently only the assembly declaring AITool.
        private static readonly IReadOnlyList<Assembly> assembliesIncludeAITools =
            typeof(AITool).Assembly.AsSingletonArray();

        #region Details

        /// <summary>
        /// Editor bootstrapper. Because of <c>[InitializeOnLoad]</c>, Unity runs the static
        /// constructor when editor scripts are loaded.
        /// </summary>
        [InitializeOnLoad]
        private class ToolInstaller
        {
            // Generated Python module written into the tools directory.
            // The same name is used in the exec() line injected into __init__.py.
            const string GeneratedFileName = "generated_by_cosmos_editor.g.py";

            // Steps: locate the MCP server's tools directory, prepare __init__.py, register every
            // AITool with CommandRegistry, then write the generated Python bridge module.
            // NOTE(review): exceptions (e.g. the tools directory is not found) escape this static
            // constructor, so any failure aborts registration of all tools.
            static ToolInstaller()
            {
                var pyToolsDir = SearchMCPBridgeToolsDir();
                ClearPythonCache(pyToolsDir);
                InjectInitCode(pyToolsDir);

                StringBuilder pyCodeWriter = new(capacity: 65536);
                WritePythonHeader(pyCodeWriter);

                foreach (var asm in assembliesIncludeAITools)
                {
                    foreach (var type in LoadableTypes(asm))
                    {
                        // Only concrete, non-generic subclasses can be instantiated.
                        if (!type.IsSubclassOf(typeof(AITool))) continue;
                        if (type.IsGenericType) continue;
                        if (type.IsAbstract) continue;

                        var aiTool = (AITool)Activator.CreateInstance(type);

                        CommandRegistry.Add(
                            aiTool.name,
                            HandlerWrapper(type, aiTool.Execute));

                        WritePythonBridge(pyCodeWriter, aiTool);
                    }
                }

                // UTF-8 without a BOM: the same bytes Encoding.UTF8.GetBytes would produce.
                File.WriteAllText(
                    Path.Combine(pyToolsDir, GeneratedFileName),
                    pyCodeWriter.ToString(),
                    new UTF8Encoding(encoderShouldEmitUTF8Identifier: false));
            }

            // Returns the loadable types of asm. GetTypes() throws ReflectionTypeLoadException
            // if any single type fails to load. In that case the successfully loaded types are
            // still available in ex.Types; the null entries are skipped.
            static IEnumerable<Type> LoadableTypes(Assembly asm)
            {
                try
                {
                    return asm.GetTypes();
                }
                catch (ReflectionTypeLoadException ex)
                {
                    return ex.Types.OfType<Type>();
                }
            }

            // Writes the imports shared by all generated bridge functions.
            static void WritePythonHeader(StringBuilder pyWriter)
            {
                pyWriter
                    .AppendLine("from typing import Dict, Any")
                    .AppendLine("from mcp.server.fastmcp import FastMCP, Context")
                    .AppendLine("from unity_connection import get_unity_connection, async_send_command_with_retry")
                    .AppendLine("from config import config")
                    .AppendLine("import time")
                    .AppendLine("import asyncio")
                    .AppendLine();
            }

            // Emits one async @mcp.tool() function. It packs its arguments into params_dict,
            // drops None values, and forwards the dict to Unity under the tool's name.
            static void WritePythonBridge(StringBuilder pyWriter, AITool aiTool)
            {
                // Read each overridable property once; implementations may rebuild the value
                // (e.g. allocate a new list) on every access.
                var toolName = aiTool.name;
                var args = aiTool.argDesc;
                var resultDesc = aiTool.resultDesc;

                pyWriter.AppendLine("@mcp.tool()");

                pyWriter
                    .Append("async def ")
                    .Append(toolName)
                    .Append("(ctx: Context");

                foreach (var arg in args)
                    pyWriter.Append(", ").Append(arg.argName);

                pyWriter.AppendLine(") -> Dict[str, Any]:");

                // Docstring. All free text is escaped so it cannot terminate the """ literal.
                pyWriter
                    .Append("    \"\"\"")
                    .AppendLine(ToDocstringText(aiTool.desc, "    "))
                    .AppendLine();

                if (args.Count > 0)
                {
                    pyWriter.AppendLine("    Args:");

                    foreach (var arg in args)
                    {
                        pyWriter
                            .Append("        ")
                            .Append(arg.argName)
                            .Append(" (")
                            .Append(FormatTypeName(arg.argType))
                            .Append("): ")
                            .AppendLine(ToDocstringText(arg.description, "            "));
                    }

                    pyWriter.AppendLine();
                }

                if (resultDesc != null)
                {
                    pyWriter
                        .AppendLine("    Returns:")
                        .Append("        ")
                        .AppendLine(ToDocstringText(resultDesc, "        "))
                        .AppendLine();
                }

                pyWriter
                    .AppendLine("    \"\"\"")
                    .AppendLine()
                    .AppendLine("    params_dict = {");

                for (int i = 0; i < args.Count; ++i)
                {
                    var arg = args[i];

                    pyWriter
                        .Append("        \"")
                        .Append(arg.argName)
                        .Append("\": ")
                        .Append(arg.argName)
                        .AppendLine(i != args.Count - 1 ? "," : "");
                }

                pyWriter
                    .AppendLine("    }")
                    .AppendLine()
                    .AppendLine("    params_dict = {k: v for k, v in params_dict.items() if v is not None}")
                    .AppendLine("    loop = asyncio.get_running_loop()")
                    .AppendLine("    connection = get_unity_connection()")
                    .Append("    result = await async_send_command_with_retry(\"")
                    .Append(toolName)
                    .AppendLine("\", params_dict, loop=loop)")
                    .AppendLine("    return result if isinstance(result, dict) else {\"success\": False, \"message\": str(result)}")
                    .AppendLine()
                    .AppendLine();
            }

            // Makes arbitrary text safe inside a non-raw """...""" Python docstring.
            // Backslashes and double quotes are escaped, so the text cannot close the literal
            // or form escape sequences. Every line after the first is prefixed with
            // continuationIndent so multi-line text keeps the docstring's indentation.
            static string ToDocstringText(string text, string continuationIndent)
            {
                var escaped = text
                    .Replace("\\", "\\\\")
                    .Replace("\"", "\\\"");

                var lines = escaped
                    .Replace("\r\n", "\n")
                    .Replace('\r', '\n')
                    .Split('\n');

                return string.Join(Environment.NewLine + continuationIndent, lines);
            }

            // Human-readable type name for docstrings: "List<Int32>" instead of "List`1".
            static string FormatTypeName(Type type)
            {
                if (!type.IsGenericType)
                    return type.Name;

                var baseName = type.Name;
                var tick = baseName.IndexOf('`');
                if (tick >= 0)
                    baseName = baseName.Substring(0, tick);

                var typeArgs = string.Join(
                    ", ",
                    type.GetGenericArguments().Select(t => FormatTypeName(t)));

                return baseName + "<" + typeArgs + ">";
            }

            // Returns the MCP server's tools directory. A directory counts as the right one only
            // if it contains manage_gameobject.py. Only the location under LocalApplicationData
            // is probed (the original note says this is the Windows layout).
            // Throws InvalidOperationException when the directory is not found; the message
            // means "Could not find MCPBridge's tools directory; make sure MCPBridge is
            // installed correctly."
            static string SearchMCPBridgeToolsDir()
            {
                var localAppData = Environment.GetFolderPath(
                    Environment.SpecialFolder.LocalApplicationData);

                var toolsDir = Path.Combine(
                    localAppData,
                    "UnityMCP",
                    "UnityMcpServer",
                    "src",
                    "tools");

                if (File.Exists(Path.Combine(toolsDir, "manage_gameobject.py")))
                    return toolsDir;

                throw new InvalidOperationException(
                    "未能找到MCPBridge的tools目录,请确定MCPBridge已正确安装。");
            }

            // Deletes the tools directory's __pycache__ (if any), presumably so the server does
            // not reuse stale bytecode.
            static void ClearPythonCache(string pyDir)
            {
                var path = Path.Combine(pyDir, "__pycache__");
                if (Directory.Exists(path))
                    Directory.Delete(path, true);
            }

            // Appends (once) an exec() of the generated module to tools/__init__.py.
            // NOTE(review): the line is appended at the end of the file with a 4-space indent.
            // That assumes __init__.py ends inside a function body where a local `mcp` is in
            // scope. The relative "./tools/..." path also assumes the server's working directory
            // is the tools directory's parent. Confirm against the installed server version.
            static void InjectInitCode(string pyDir)
            {
                var initPyScript = Path.Combine(pyDir, "__init__.py");
                var lines = File.ReadAllLines(initPyScript).ToList();

                const string injectStr =
                    "exec(open(\"./tools/" + GeneratedFileName + "\").read(), { \"mcp\": mcp })";

                if (lines.Any(x => x.Trim() == injectStr))
                    return;

                lines.Add("    " + injectStr);

                // Encoding.UTF8 would prepend a BOM; write plain UTF-8 instead.
                File.WriteAllLines(
                    initPyScript,
                    lines,
                    new UTF8Encoding(encoderShouldEmitUTF8Identifier: false));
            }

            // Adapts a tool's Execute into a CommandRegistry handler.
            // Success -> { success = true, message = result }, with a UnitType result reported
            // as "success".
            // Any exception -> { success = false, message, csharpToolType, exception }.
            static Func<JObject, object> HandlerWrapper(
                Type aiToolType,
                Func<JObject, object> handler)
            {
                object f(JObject arg)
                {
                    try
                    {
                        var resp = handler(arg);

                        if (resp is UnitType)
                            resp = "success";

                        return new
                        {
                            success = true,
                            message = resp
                        };
                    }
                    catch (Exception ex)
                    {
                        return new
                        {
                            success = false,
                            message = ex.Message,
                            csharpToolType = aiToolType,
                            exception = ex.ToString()
                        };
                    }
                }

                return f;
            }
        }

        #endregion
    }
}

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions