Skip to content

Commit 5ff2c3b

Browse files
JoeCP17markpollack
authored andcommitted
modify message type get value
1 parent cffa790 commit 5ff2c3b

File tree

4 files changed

+13
-5
lines changed

4 files changed

+13
-5
lines changed

spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/client/AzureOpenAiClient.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ public AiResponse generate(Prompt prompt) {
7373
List<Message> messages = prompt.getMessages();
7474
List<ChatMessage> azureMessages = new ArrayList<>();
7575
for (Message message : messages) {
76-
String messageType = message.getMessageType().getValue();
76+
String messageType = message.getMessageTypeValue();
7777
ChatRole chatRole = ChatRole.fromString(messageType);
7878
azureMessages.add(new ChatMessage(chatRole, message.getContent()));
7979
}

spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/AbstractMessage.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,4 +87,9 @@ public MessageType getMessageType() {
8787
return this.messageType;
8888
}
8989

90+
@Override
91+
public String getMessageTypeValue() {
92+
return this.messageType.getValue();
93+
}
94+
9095
}

spring-ai-core/src/main/java/org/springframework/ai/prompt/messages/Message.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,6 @@ public interface Message {
2626

2727
MessageType getMessageType();
2828

29+
String getMessageTypeValue();
30+
2931
}

spring-ai-openai/src/main/java/org/springframework/ai/openai/client/OpenAiClient.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.springframework.ai.client.Generation;
2828
import org.springframework.ai.prompt.Prompt;
2929
import org.springframework.ai.prompt.messages.Message;
30+
import org.springframework.ai.prompt.messages.MessageType;
3031
import org.springframework.util.Assert;
3132

3233
import java.util.ArrayList;
@@ -79,7 +80,7 @@ public AiResponse generate(Prompt prompt) {
7980
List<Message> messages = prompt.getMessages();
8081
List<ChatMessage> theoMessages = new ArrayList<>();
8182
for (Message message : messages) {
82-
String messageType = message.getMessageType().getValue();
83+
String messageType = message.getMessageTypeValue();
8384
theoMessages.add(new ChatMessage(messageType, message.getContent()));
8485
}
8586
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder()
@@ -148,14 +149,14 @@ private List<ChatMessage> convertToChatMessages(List<Message> messages) {
148149
for (Message promptMessage : messages) {
149150
switch (promptMessage.getMessageType()) {
150151
case USER:
151-
chatMessages.add(new ChatMessage("user", promptMessage.getContent()));
152+
chatMessages.add(new ChatMessage(MessageType.USER.getValue(), promptMessage.getContent()));
152153
break;
153154
case ASSISTANT:
154155
// TODO - valid?
155-
chatMessages.add(new ChatMessage("assistant", promptMessage.getContent()));
156+
chatMessages.add(new ChatMessage(MessageType.ASSISTANT.getValue(), promptMessage.getContent()));
156157
break;
157158
case SYSTEM:
158-
chatMessages.add(new ChatMessage("system", promptMessage.getContent()));
159+
chatMessages.add(new ChatMessage(MessageType.SYSTEM.getValue(), promptMessage.getContent()));
159160
break;
160161
case FUNCTION:
161162
logger.error(

0 commit comments

Comments
 (0)