Skip to content

Commit 06070e7

Browse files
Mohammed, Ahmed yousri salama (Canada)Mohammed, Ahmed yousri salama (Canada)
authored andcommitted
Added support for OpenAI's content policy moderation API
1 parent 43dbaf4 commit 06070e7

File tree

23 files changed

+1964
-1
lines changed

23 files changed

+1964
-1
lines changed
Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
/*
2+
* Copyright 2024-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.openai;
18+
19+
import org.slf4j.Logger;
20+
import org.slf4j.LoggerFactory;
21+
import org.springframework.ai.model.ModelOptionsUtils;
22+
import org.springframework.ai.moderation.*;
23+
import org.springframework.ai.openai.api.OpenAiApi;
24+
import org.springframework.ai.openai.api.OpenAiModerationApi;
25+
import org.springframework.http.ResponseEntity;
26+
import org.springframework.retry.RetryCallback;
27+
import org.springframework.retry.RetryContext;
28+
import org.springframework.retry.RetryListener;
29+
import org.springframework.retry.support.RetryTemplate;
30+
import org.springframework.util.Assert;
31+
32+
import java.time.Duration;
33+
import java.util.ArrayList;
34+
import java.util.List;
35+
36+
/**
37+
* OpenAiModerationClient is a class that implements the ModerationClient interface. It
38+
* provides a client for calling the OpenAI moderation generation API.
39+
*
40+
* @author Ahmed Yousri
41+
* @since 0.9.0
42+
*/
43+
public class OpenAiModerationClient implements ModerationClient {
44+
45+
private final Logger logger = LoggerFactory.getLogger(getClass());
46+
47+
private OpenAiModerationOptions defaultOptions;
48+
49+
private final OpenAiModerationApi openAiModerationApi;
50+
51+
public final RetryTemplate retryTemplate = RetryTemplate.builder()
52+
.maxAttempts(10)
53+
.retryOn(OpenAiApi.OpenAiApiException.class)
54+
.exponentialBackoff(Duration.ofMillis(2000), 5, Duration.ofMillis(3 * 60000))
55+
.withListener(new RetryListener() {
56+
public <T extends Object, E extends Throwable> void onError(RetryContext context,
57+
RetryCallback<T, E> callback, Throwable throwable) {
58+
logger.warn("Retry error. Retry count:" + context.getRetryCount(), throwable);
59+
};
60+
})
61+
.build();
62+
63+
public OpenAiModerationClient(OpenAiModerationApi openAiModerationApi) {
64+
Assert.notNull(openAiModerationApi, "OpenAiModerationApi must not be null");
65+
this.openAiModerationApi = openAiModerationApi;
66+
}
67+
68+
public OpenAiModerationOptions getDefaultOptions() {
69+
return this.defaultOptions;
70+
}
71+
72+
public OpenAiModerationClient withDefaultOptions(OpenAiModerationOptions defaultOptions) {
73+
this.defaultOptions = defaultOptions;
74+
return this;
75+
}
76+
77+
@Override
78+
public ModerationResponse call(ModerationPrompt moderationPrompt) {
79+
return this.retryTemplate.execute(ctx -> {
80+
81+
String instructions = moderationPrompt.getInstructions().getText();
82+
83+
OpenAiModerationApi.OpenAiModerationRequest moderationRequest = new OpenAiModerationApi.OpenAiModerationRequest(
84+
instructions);
85+
86+
if (this.defaultOptions != null) {
87+
moderationRequest = ModelOptionsUtils.merge(this.defaultOptions, moderationRequest,
88+
OpenAiModerationApi.OpenAiModerationRequest.class);
89+
}
90+
91+
if (moderationPrompt.getOptions() != null) {
92+
moderationRequest = ModelOptionsUtils.merge(toOpenAiModerationOptions(moderationPrompt.getOptions()),
93+
moderationRequest, OpenAiModerationApi.OpenAiModerationRequest.class);
94+
}
95+
96+
ResponseEntity<OpenAiModerationApi.OpenAiModerationResponse> moderationResponseEntity = this.openAiModerationApi
97+
.createModeration(moderationRequest);
98+
99+
return convertResponse(moderationResponseEntity, moderationRequest);
100+
});
101+
}
102+
103+
private ModerationResponse convertResponse(
104+
ResponseEntity<OpenAiModerationApi.OpenAiModerationResponse> moderationResponseEntity,
105+
OpenAiModerationApi.OpenAiModerationRequest openAiModerationRequest) {
106+
OpenAiModerationApi.OpenAiModerationResponse moderationApiResponse = moderationResponseEntity.getBody();
107+
if (moderationApiResponse == null) {
108+
logger.warn("No moderation response returned for request: {}", openAiModerationRequest);
109+
return new ModerationResponse(Generation.NULL);
110+
}
111+
112+
List<ModerationResult> moderationResults = new ArrayList<>();
113+
if (moderationApiResponse.results() != null) {
114+
115+
for (OpenAiModerationApi.OpenAiModerationResult result : moderationApiResponse.results()) {
116+
Categories categories = null;
117+
CategoryScores categoryScores = null;
118+
if (result.categories() != null) {
119+
categories = Categories.builder()
120+
.withSexual(result.categories().sexual())
121+
.withHate(result.categories().hate())
122+
.withHarassment(result.categories().harassment())
123+
.withSelfHarm(result.categories().selfHarm())
124+
.withSexualMinors(result.categories().sexualMinors())
125+
.withHateThreatening(result.categories().hateThreatening())
126+
.withViolenceGraphic(result.categories().violenceGraphic())
127+
.withSelfHarmIntent(result.categories().selfHarmIntent())
128+
.withSelfHarmInstructions(result.categories().selfHarmInstructions())
129+
.withHarassmentThreatening(result.categories().harassmentThreatening())
130+
.withViolence(result.categories().violence())
131+
.build();
132+
}
133+
if (result.categoryScores() != null) {
134+
categoryScores = CategoryScores.builder()
135+
.withHate(result.categoryScores().hate())
136+
.withHateThreatening(result.categoryScores().hateThreatening())
137+
.withHarassment(result.categoryScores().harassment())
138+
.withHarassmentThreatening(result.categoryScores().harassmentThreatening())
139+
.withSelfHarm(result.categoryScores().selfHarm())
140+
.withSelfHarmIntent(result.categoryScores().selfHarmIntent())
141+
.withSelfHarmInstructions(result.categoryScores().selfHarmInstructions())
142+
.withSexual(result.categoryScores().sexual())
143+
.withSexualMinors(result.categoryScores().sexualMinors())
144+
.withViolence(result.categoryScores().violence())
145+
.withViolenceGraphic(result.categoryScores().violenceGraphic())
146+
.build();
147+
}
148+
ModerationResult moderationResult = ModerationResult.builder()
149+
.withCategories(categories)
150+
.withCategoryScores(categoryScores)
151+
.withFlagged(result.flagged())
152+
.build();
153+
moderationResults.add(moderationResult);
154+
}
155+
156+
}
157+
158+
Moderation moderation = Moderation.builder()
159+
.withId(moderationApiResponse.id())
160+
.withModel(moderationApiResponse.model())
161+
.withResults(moderationResults)
162+
.build();
163+
164+
return new ModerationResponse(new Generation(moderation));
165+
}
166+
167+
/**
168+
* Convert the {@link ModerationOptions} into {@link OpenAiModerationOptions}.
169+
* @return the converted {@link OpenAiModerationOptions}.
170+
*/
171+
private OpenAiModerationOptions toOpenAiModerationOptions(ModerationOptions runtimeModerationOptions) {
172+
OpenAiModerationOptions.Builder openAiModerationOptionsBuilder = OpenAiModerationOptions.builder();
173+
if (runtimeModerationOptions != null) {
174+
// Handle portable moderation options
175+
if (runtimeModerationOptions.getModel() != null) {
176+
openAiModerationOptionsBuilder.withModel(runtimeModerationOptions.getModel());
177+
}
178+
}
179+
return openAiModerationOptionsBuilder.build();
180+
}
181+
182+
}
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Copyright 2024-2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.ai.openai;
17+
18+
import com.fasterxml.jackson.annotation.JsonInclude;
19+
import com.fasterxml.jackson.annotation.JsonProperty;
20+
import org.springframework.ai.moderation.ModerationOptions;
21+
import org.springframework.ai.openai.api.OpenAiModerationApi;
22+
23+
/**
24+
* OpenAI Moderation API options. OpenAiModerationOptions.java
25+
*
26+
* @author Ahmed Yousri
27+
* @since 0.9.0
28+
*/
29+
@JsonInclude(JsonInclude.Include.NON_NULL)
30+
public class OpenAiModerationOptions implements ModerationOptions {
31+
32+
/**
33+
* The model to use for moderation generation.
34+
*/
35+
@JsonProperty("model")
36+
private String model = OpenAiModerationApi.DEFAULT_MODERATION_MODEL;
37+
38+
public static Builder builder() {
39+
return new Builder();
40+
}
41+
42+
public static class Builder {
43+
44+
private final OpenAiModerationOptions options;
45+
46+
private Builder() {
47+
this.options = new OpenAiModerationOptions();
48+
}
49+
50+
public Builder withModel(String model) {
51+
options.setModel(model);
52+
return this;
53+
}
54+
55+
public OpenAiModerationOptions build() {
56+
return options;
57+
}
58+
59+
}
60+
61+
@Override
62+
public String getModel() {
63+
return this.model;
64+
}
65+
66+
public void setModel(String model) {
67+
this.model = model;
68+
}
69+
70+
}

0 commit comments

Comments
 (0)