diff --git a/litellm/litellm_core_utils/prompt_templates/factory.py b/litellm/litellm_core_utils/prompt_templates/factory.py
index 77cbe4c9a8e7..da9899cf5e39 100644
--- a/litellm/litellm_core_utils/prompt_templates/factory.py
+++ b/litellm/litellm_core_utils/prompt_templates/factory.py
@@ -40,6 +40,7 @@
 from .common_utils import convert_content_list_to_str, is_non_content_values_set
 from .image_handling import convert_url_to_base64
+from pydantic import BaseModel
 
 
 def default_pt(messages):
@@ -1001,6 +1002,8 @@ def infer_protocol_value(
 def _gemini_tool_call_invoke_helper(
     function_call_params: ChatCompletionToolCallFunctionChunk,
 ) -> Optional[VertexFunctionCall]:
+    if isinstance(function_call_params, BaseModel):
+        function_call_params = function_call_params.model_dump()
     name = function_call_params.get("name", "") or ""
     arguments = function_call_params.get("arguments", "")
     if (
@@ -1145,7 +1148,10 @@ def convert_to_gemini_tool_call_result(
                 and prev_tool_call_id
                 and msg_tool_call_id == prev_tool_call_id
             ):
-                name = tool.get("function", {}).get("name", "")
+                function_idf = tool.get("function")
+                if isinstance(function_idf, BaseModel):
+                    function_idf = function_idf.model_dump()
+                name = function_idf.get("name", "") if isinstance(function_idf, dict) else ""
 
     if not name:
         raise Exception(
diff --git a/litellm/llms/custom_httpx/llm_http_handler.py b/litellm/llms/custom_httpx/llm_http_handler.py
index d404077a5b62..2d0cef9be5b0 100644
--- a/litellm/llms/custom_httpx/llm_http_handler.py
+++ b/litellm/llms/custom_httpx/llm_http_handler.py
@@ -79,6 +79,7 @@
     ModelResponse,
     ProviderConfigManager,
 )
+from pydantic import BaseModel
 
 if TYPE_CHECKING:
     from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
@@ -108,6 +109,17 @@ async def _make_common_async_call(
             provider_config.max_retry_on_unprocessable_entity_error
         )
 
+        # Normalize tool_call function objects (possibly pydantic models) to plain dicts.
+        for message in data.get("messages", []):
+            if message.get("role") == "assistant" and message.get("tool_calls"):
+                for tool in message["tool_calls"]:
+                    fn = tool.get("function")
+                    if fn:
+                        tool["function"] = (
+                            fn.model_dump() if isinstance(fn, BaseModel)
+                            else getattr(fn, "__dict__", fn)
+                        )
+
         response: Optional[httpx.Response] = None
         for i in range(max(max_retry_on_unprocessable_entity_error, 1)):
             try:
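
For reference, a minimal standalone sketch of the normalization both files now apply. The `Function` model and `normalize` helper below are hypothetical stand-ins used only for illustration, not litellm's real classes:

```python
from pydantic import BaseModel


# Hypothetical stand-in for litellm's pydantic tool-call function type.
class Function(BaseModel):
    name: str
    arguments: str


def normalize(fn):
    # Same branching as the llm_http_handler.py hunk: pydantic models are
    # dumped to plain dicts; anything else falls back to its __dict__,
    # with the object itself (e.g. an ordinary dict) as the default.
    return fn.model_dump() if isinstance(fn, BaseModel) else getattr(fn, "__dict__", fn)


assert normalize(Function(name="get_weather", arguments="{}")) == {
    "name": "get_weather",
    "arguments": "{}",
}
assert normalize({"name": "get_weather", "arguments": "{}"}) == {
    "name": "get_weather",
    "arguments": "{}",
}
```

Plain dicts pass through unchanged because dict instances expose no `__dict__` attribute, so `getattr` returns the default, i.e. the dict itself.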