Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions async-openai/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ native-tls-vendored = ["reqwest/native-tls-vendored"]
realtime = ["dep:tokio-tungstenite"]
# Bring your own types
byot = []
# Deserialize error responses yourself
string-errors = []

[dependencies]
async-openai-macros = { path = "../async-openai-macros", version = "0.1.0" }
Expand Down Expand Up @@ -59,6 +61,10 @@ serde_json = "1.0"
name = "bring-your-own-type"
required-features = ["byot"]

[[test]]
name = "string-errors"
required-features = ["string-errors", "byot"]

[package.metadata.docs.rs]
all-features = true
rustdoc-args = ["--cfg", "docsrs"]
7 changes: 7 additions & 0 deletions async-openai/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,13 @@ This can be useful in many scenarios:
Visit [examples/bring-your-own-type](https://github.com/64bit/async-openai/tree/main/examples/bring-your-own-type)
directory to learn more.

## String Errors

Enable the `string-errors` feature to receive API errors as raw strings instead of parsed structs. This can be useful
in scenarios where providers expose errors in different formats.

See [examples/string-errors](https://github.com/64bit/async-openai/tree/main/examples/string-errors) for usage.

## Dynamic Dispatch for Different Providers

For any struct that implements `Config` trait, you can wrap it in a smart pointer and cast the pointer to `dyn Config`
Expand Down
32 changes: 28 additions & 4 deletions async-openai/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@ use reqwest::{multipart::Form, Response};
use reqwest_eventsource::{Error as EventSourceError, Event, EventSource, RequestBuilderExt};
use serde::{de::DeserializeOwned, Serialize};

#[cfg(not(feature = "string-errors"))]
use crate::error::{ApiError, WrappedError};

use crate::{
config::{Config, OpenAIConfig},
error::{map_deserialization_error, ApiError, OpenAIError, StreamError, WrappedError},
error::{map_deserialization_error, OpenAIError, StreamError},
file::Files,
image::Images,
moderation::Moderations,
Expand Down Expand Up @@ -366,6 +369,7 @@ impl<C: Config> Client<C> {
Ok(bytes) => Ok(bytes),
Err(e) => {
match e {
#[cfg(not(feature = "string-errors"))]
OpenAIError::ApiError(api_error) => {
if status.is_server_error() {
Err(backoff::Error::Transient {
Expand All @@ -385,6 +389,17 @@ impl<C: Config> Client<C> {
Err(backoff::Error::Permanent(OpenAIError::ApiError(api_error)))
}
}
#[cfg(feature = "string-errors")]
OpenAIError::ApiError(api_error) => {
if status.is_server_error() {
Err(backoff::Error::Transient {
err: OpenAIError::ApiError(api_error),
retry_after: None,
})
} else {
Err(backoff::Error::Permanent(OpenAIError::ApiError(api_error)))
}
}
_ => Err(backoff::Error::Permanent(e)),
}
}
Expand Down Expand Up @@ -483,6 +498,7 @@ async fn read_response(response: Response) -> Result<Bytes, OpenAIError> {
let status = response.status();
let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;

#[cfg(not(feature = "string-errors"))]
if status.is_server_error() {
// OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
let message: String = String::from_utf8_lossy(&bytes).into_owned();
Expand All @@ -497,10 +513,18 @@ async fn read_response(response: Response) -> Result<Bytes, OpenAIError> {

// Deserialize response body from either error object or actual response object
if !status.is_success() {
let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
.map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
#[cfg(not(feature = "string-errors"))]
{
let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
.map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
return Err(OpenAIError::ApiError(wrapped_error.error));
}

return Err(OpenAIError::ApiError(wrapped_error.error));
#[cfg(feature = "string-errors")]
{
let message: String = String::from_utf8_lossy(&bytes).into_owned();
return Err(OpenAIError::ApiError(crate::error::RawApiError(message)));
}
}

Ok(bytes)
Expand Down
28 changes: 28 additions & 0 deletions async-openai/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,42 @@

use serde::{Deserialize, Serialize};

/// Raw API error string from providers with custom error formats
///
/// Only available with the `string-errors` feature.
#[cfg(feature = "string-errors")]
#[derive(Debug, Clone)]
pub struct RawApiError(pub String);

#[cfg(feature = "string-errors")]
impl std::fmt::Display for RawApiError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}

#[cfg(feature = "string-errors")]
impl RawApiError {
/// Parse the raw error string into a custom type
pub fn parse<T: serde::de::DeserializeOwned>(&self) -> Result<T, serde_json::Error> {
serde_json::from_str(&self.0)
}
}

#[derive(Debug, thiserror::Error)]
pub enum OpenAIError {
/// Underlying error from reqwest library after an API call was made
#[error("http error: {0}")]
Reqwest(#[from] reqwest::Error),
/// OpenAI returns error object with details of API call failure
#[cfg(not(feature = "string-errors"))]
#[error("{0}")]
ApiError(ApiError),
/// Some OpenAI compatible services return error messages in diverge in error formats.
/// This feature leaves deserialization to the user, not even assuming json.
#[cfg(feature = "string-errors")]
#[error("{0}")]
ApiError(RawApiError),
/// Error when a response cannot be deserialized into a Rust type
#[error("failed to deserialize api response: error:{0} content:{1}")]
JSONDeserialize(serde_json::Error, String),
Expand Down
20 changes: 20 additions & 0 deletions async-openai/tests/string-errors.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#![allow(dead_code)]
//! The purpose of this test to make sure that with the string-errors feature enabled, the error is returned as a string.
//! Enabling the byot feature allows for a simpler test, as the body can be written as an empty json value.

use async_openai::{error::OpenAIError, Client};
use serde_json::{json, Value};

#[tokio::test]
async fn test_byot_errors() {
let client = Client::new();

let _r: Result<Value, OpenAIError> = client.chat().create_byot(json!({})).await;

match _r.unwrap_err() {
OpenAIError::ApiError(raw_error) => {
let _value: Value = raw_error.parse().unwrap();
}
_ => {}
};
}
10 changes: 10 additions & 0 deletions examples/string-errors/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[package]
name = "string-errors"
version = "0.1.0"
edition = "2021"

[dependencies]
async-openai = { path = "../../async-openai", features = ["string-errors", "byot"] }
tokio = { version = "1.43.0", features = ["full"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
3 changes: 3 additions & 0 deletions examples/string-errors/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# String Errors Example

This example demonstrates how to use the `string-errors` feature to handle API errors from providers that use different error formats than OpenAI.
42 changes: 42 additions & 0 deletions examples/string-errors/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
//! This example demonstrates how errors from OpenRouter can be parsed by the library consumer.
//! It uses the `string-errors` feature to receive API errors as raw strings instead of parsed structs.

use async_openai::{config::OpenAIConfig, error::OpenAIError, Client};
use serde::{Deserialize, Serialize};
use serde_json::json;

#[derive(Debug, Deserialize, Serialize)]
struct OpenRouterError {
code: i32,
message: String,
}

#[derive(Debug, Deserialize)]
struct ErrorWrapper {
error: OpenRouterError,
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let config = OpenAIConfig::new().with_api_base("https://openrouter.ai/api/v1");
let client = Client::with_config(config);

let result: Result<serde_json::Value, OpenAIError> = client
.chat()
.create_byot(json!({
"model": "invalid-model",
"messages": [{"role": "user", "content": "Hello"}]
}))
.await;

match result.unwrap_err() {
OpenAIError::ApiError(raw_error) => {
let error: ErrorWrapper = raw_error.parse().unwrap();
println!("Code: {}", error.error.code);
println!("Message: {}", error.error.message);
}
_ => panic!("Expected OpenAIError::ApiError"),
}

Ok(())
}
Loading