diff --git a/async-openai/Cargo.toml b/async-openai/Cargo.toml index c3707e69..1a6e1bcd 100644 --- a/async-openai/Cargo.toml +++ b/async-openai/Cargo.toml @@ -25,6 +25,8 @@ native-tls-vendored = ["reqwest/native-tls-vendored"] realtime = ["dep:tokio-tungstenite"] # Bring your own types byot = [] +# Deserialize error responses yourself +string-errors = [] [dependencies] async-openai-macros = { path = "../async-openai-macros", version = "0.1.0" } @@ -59,6 +61,10 @@ serde_json = "1.0" name = "bring-your-own-type" required-features = ["byot"] +[[test]] +name = "string-errors" +required-features = ["string-errors", "byot"] + [package.metadata.docs.rs] all-features = true rustdoc-args = ["--cfg", "docsrs"] diff --git a/async-openai/README.md b/async-openai/README.md index 9b1fdcab..a93151db 100644 --- a/async-openai/README.md +++ b/async-openai/README.md @@ -145,6 +145,13 @@ This can be useful in many scenarios: Visit [examples/bring-your-own-type](https://github.com/64bit/async-openai/tree/main/examples/bring-your-own-type) directory to learn more. +## String Errors + +Enable the `string-errors` feature to receive API errors as raw strings instead of parsed structs. This can be useful +in scenarios where providers expose errors in different formats. + +See [examples/string-errors](https://github.com/64bit/async-openai/tree/main/examples/string-errors) for usage. 
+ ## Dynamic Dispatch for Different Providers For any struct that implements `Config` trait, you can wrap it in a smart pointer and cast the pointer to `dyn Config` diff --git a/async-openai/src/client.rs b/async-openai/src/client.rs index d73a2329..53737f91 100644 --- a/async-openai/src/client.rs +++ b/async-openai/src/client.rs @@ -6,9 +6,12 @@ use reqwest::{multipart::Form, Response}; use reqwest_eventsource::{Error as EventSourceError, Event, EventSource, RequestBuilderExt}; use serde::{de::DeserializeOwned, Serialize}; +#[cfg(not(feature = "string-errors"))] +use crate::error::{ApiError, WrappedError}; + use crate::{ config::{Config, OpenAIConfig}, - error::{map_deserialization_error, ApiError, OpenAIError, StreamError, WrappedError}, + error::{map_deserialization_error, OpenAIError, StreamError}, file::Files, image::Images, moderation::Moderations, @@ -366,6 +369,7 @@ impl<C: Config> Client<C> { Ok(bytes) => Ok(bytes), Err(e) => { match e { + #[cfg(not(feature = "string-errors"))] OpenAIError::ApiError(api_error) => { if status.is_server_error() { Err(backoff::Error::Transient { @@ -385,6 +389,17 @@ Err(backoff::Error::Permanent(OpenAIError::ApiError(api_error))) } } + #[cfg(feature = "string-errors")] + OpenAIError::ApiError(api_error) => { + if status.is_server_error() { + Err(backoff::Error::Transient { + err: OpenAIError::ApiError(api_error), + retry_after: None, + }) + } else { + Err(backoff::Error::Permanent(OpenAIError::ApiError(api_error))) + } + } _ => Err(backoff::Error::Permanent(e)), } } @@ -483,6 +498,7 @@ async fn read_response(response: Response) -> Result<Bytes, OpenAIError> { let status = response.status(); let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?; + #[cfg(not(feature = "string-errors"))] if status.is_server_error() { // OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
let message: String = String::from_utf8_lossy(&bytes).into_owned(); @@ -497,10 +513,18 @@ async fn read_response(response: Response) -> Result<Bytes, OpenAIError> { // Deserialize response body from either error object or actual response object if !status.is_success() { - let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref()) - .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?; + #[cfg(not(feature = "string-errors"))] + { + let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref()) + .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?; + return Err(OpenAIError::ApiError(wrapped_error.error)); + } - return Err(OpenAIError::ApiError(wrapped_error.error)); + #[cfg(feature = "string-errors")] + { + let message: String = String::from_utf8_lossy(&bytes).into_owned(); + return Err(OpenAIError::ApiError(crate::error::RawApiError(message))); + } } Ok(bytes) diff --git a/async-openai/src/error.rs b/async-openai/src/error.rs index 288d198b..e435b8df 100644 --- a/async-openai/src/error.rs +++ b/async-openai/src/error.rs @@ -2,14 +2,42 @@ use serde::{Deserialize, Serialize}; +/// Raw API error string from providers with custom error formats +/// +/// Only available with the `string-errors` feature.
+#[cfg(feature = "string-errors")] +#[derive(Debug, Clone)] +pub struct RawApiError(pub String); + +#[cfg(feature = "string-errors")] +impl std::fmt::Display for RawApiError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +#[cfg(feature = "string-errors")] +impl RawApiError { + /// Parse the raw error string into a custom type + pub fn parse<T: serde::de::DeserializeOwned>(&self) -> Result<T, serde_json::Error> { + serde_json::from_str(&self.0) + } +} + #[derive(Debug, thiserror::Error)] pub enum OpenAIError { /// Underlying error from reqwest library after an API call was made #[error("http error: {0}")] Reqwest(#[from] reqwest::Error), /// OpenAI returns error object with details of API call failure + #[cfg(not(feature = "string-errors"))] #[error("{0}")] ApiError(ApiError), + /// Some OpenAI compatible services return error messages that diverge in format. + /// This feature leaves deserialization to the user, not even assuming json. + #[cfg(feature = "string-errors")] + #[error("{0}")] + ApiError(RawApiError), /// Error when a response cannot be deserialized into a Rust type #[error("failed to deserialize api response: error:{0} content:{1}")] JSONDeserialize(serde_json::Error, String), diff --git a/async-openai/tests/string-errors.rs b/async-openai/tests/string-errors.rs new file mode 100644 index 00000000..4819005a --- /dev/null +++ b/async-openai/tests/string-errors.rs @@ -0,0 +1,20 @@ +#![allow(dead_code)] +//! The purpose of this test is to make sure that with the string-errors feature enabled, the error is returned as a string. +//! Enabling the byot feature allows for a simpler test, as the body can be written as an empty json value.
+ +use async_openai::{error::OpenAIError, Client}; +use serde_json::{json, Value}; + +#[tokio::test] +async fn test_byot_errors() { + let client = Client::new(); + + let _r: Result<Value, OpenAIError> = client.chat().create_byot(json!({})).await; + + match _r.unwrap_err() { + OpenAIError::ApiError(raw_error) => { + let _value: Value = raw_error.parse().unwrap(); + } + _ => {} + }; +} diff --git a/examples/string-errors/Cargo.toml b/examples/string-errors/Cargo.toml new file mode 100644 index 00000000..f984bf8d --- /dev/null +++ b/examples/string-errors/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "string-errors" +version = "0.1.0" +edition = "2021" + +[dependencies] +async-openai = { path = "../../async-openai", features = ["string-errors", "byot"] } +tokio = { version = "1.43.0", features = ["full"] } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" diff --git a/examples/string-errors/README.md b/examples/string-errors/README.md new file mode 100644 index 00000000..7c09301a --- /dev/null +++ b/examples/string-errors/README.md @@ -0,0 +1,3 @@ +# String Errors Example + +This example demonstrates how to use the `string-errors` feature to handle API errors from providers that use different error formats than OpenAI. \ No newline at end of file diff --git a/examples/string-errors/src/main.rs b/examples/string-errors/src/main.rs new file mode 100644 index 00000000..0ad07698 --- /dev/null +++ b/examples/string-errors/src/main.rs @@ -0,0 +1,42 @@ +//! This example demonstrates how errors from OpenRouter can be parsed by the library consumer. +//! It uses the `string-errors` feature to receive API errors as raw strings instead of parsed structs.
+ +use async_openai::{config::OpenAIConfig, error::OpenAIError, Client}; +use serde::{Deserialize, Serialize}; +use serde_json::json; + +#[derive(Debug, Deserialize, Serialize)] +struct OpenRouterError { + code: i32, + message: String, +} + +#[derive(Debug, Deserialize)] +struct ErrorWrapper { + error: OpenRouterError, +} + +#[tokio::main] +async fn main() -> Result<(), Box<dyn std::error::Error>> { + let config = OpenAIConfig::new().with_api_base("https://openrouter.ai/api/v1"); + let client = Client::with_config(config); + + let result: Result<serde_json::Value, OpenAIError> = client + .chat() + .create_byot(json!({ + "model": "invalid-model", + "messages": [{"role": "user", "content": "Hello"}] + })) + .await; + + match result.unwrap_err() { + OpenAIError::ApiError(raw_error) => { + let error: ErrorWrapper = raw_error.parse().unwrap(); + println!("Code: {}", error.error.code); + println!("Message: {}", error.error.message); + } + _ => panic!("Expected OpenAIError::ApiError"), + } + + Ok(()) +}