Skip to content

Commit 14397b2

Browse files
Improve message writer tests
1 parent 8f443fc commit 14397b2

File tree

6 files changed

+158
-147
lines changed

6 files changed

+158
-147
lines changed

communication/src/main/java/datadog/communication/serialization/Codec.java

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,36 @@
11
package datadog.communication.serialization;
22

3+
import datadog.communication.serialization.custom.aiguard.FunctionWriter;
34
import datadog.communication.serialization.custom.aiguard.MessageWriter;
45
import datadog.communication.serialization.custom.aiguard.ToolCallWriter;
56
import datadog.communication.serialization.custom.stacktrace.StackTraceEventFrameWriter;
67
import datadog.communication.serialization.custom.stacktrace.StackTraceEventWriter;
8+
import datadog.trace.api.Config;
79
import datadog.trace.api.aiguard.AIGuard;
810
import datadog.trace.util.stacktrace.StackTraceEvent;
911
import datadog.trace.util.stacktrace.StackTraceFrame;
1012
import java.nio.ByteBuffer;
1113
import java.nio.CharBuffer;
1214
import java.util.Collection;
1315
import java.util.Collections;
16+
import java.util.HashMap;
1417
import java.util.Map;
15-
import java.util.stream.Collectors;
16-
import java.util.stream.Stream;
1718

1819
public final class Codec extends ClassValue<ValueWriter<?>> {
1920

20-
private static final Map<Class<?>, ValueWriter<?>> defaultConfig =
21-
Stream.of(
22-
new Object[][] {
23-
{StackTraceEvent.class, new StackTraceEventWriter()},
24-
{StackTraceFrame.class, new StackTraceEventFrameWriter()},
25-
{AIGuard.Message.class, new MessageWriter()},
26-
{AIGuard.ToolCall.class, new ToolCallWriter()},
27-
})
28-
.collect(Collectors.toMap(data -> (Class<?>) data[0], data -> (ValueWriter<?>) data[1]));
29-
30-
public static final Codec INSTANCE = new Codec(defaultConfig);
21+
public static final Codec INSTANCE;
22+
23+
static {
24+
final Map<Class<?>, ValueWriter<?>> writers = new HashMap<>(1 << 3);
25+
writers.put(StackTraceEvent.class, new StackTraceEventWriter());
26+
writers.put(StackTraceFrame.class, new StackTraceEventFrameWriter());
27+
if (Config.get().isAiGuardEnabled()) {
28+
writers.put(AIGuard.Message.class, new MessageWriter());
29+
writers.put(AIGuard.ToolCall.class, new ToolCallWriter());
30+
writers.put(AIGuard.ToolCall.Function.class, new FunctionWriter());
31+
}
32+
INSTANCE = new Codec(writers);
33+
}
3134

3235
private final Map<Class<?>, ValueWriter<?>> config;
3336

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package datadog.communication.serialization.custom.aiguard;
2+
3+
import datadog.communication.serialization.EncodingCache;
4+
import datadog.communication.serialization.ValueWriter;
5+
import datadog.communication.serialization.Writable;
6+
import datadog.trace.api.aiguard.AIGuard;
7+
8+
public class FunctionWriter implements ValueWriter<AIGuard.ToolCall.Function> {
9+
10+
@Override
11+
public void write(
12+
final AIGuard.ToolCall.Function function,
13+
final Writable writable,
14+
final EncodingCache encodingCache) {
15+
writable.startMap(2);
16+
writable.writeString("name", encodingCache);
17+
writable.writeString(function.getName(), encodingCache);
18+
writable.writeString("arguments", encodingCache);
19+
writable.writeString(function.getArguments(), encodingCache);
20+
}
21+
}

communication/src/main/java/datadog/communication/serialization/custom/aiguard/MessageWriter.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,7 @@ private static void writeToolCallArray(
4444
final EncodingCache encodingCache) {
4545
if (present) {
4646
writable.writeString(key, encodingCache);
47-
writable.startArray(values.size());
48-
for (final AIGuard.ToolCall toolCall : values) {
49-
writable.writeObject(toolCall, encodingCache);
50-
}
47+
writable.writeObject(values, encodingCache);
5148
}
5249
}
5350

communication/src/main/java/datadog/communication/serialization/custom/aiguard/ToolCallWriter.java

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,6 @@ public void write(
1414
writable.writeString("id", encodingCache);
1515
writable.writeString(value.getId(), encodingCache);
1616
writable.writeString("function", encodingCache);
17-
18-
final AIGuard.ToolCall.Function function = value.getFunction();
19-
if (function != null) {
20-
writable.startMap(2);
21-
writable.writeString("name", encodingCache);
22-
writable.writeString(function.getName(), encodingCache);
23-
writable.writeString("arguments", encodingCache);
24-
writable.writeString(function.getArguments(), encodingCache);
25-
}
17+
writable.writeObject(value.getFunction(), encodingCache);
2618
}
2719
}
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
package datadog.communication.serialization.aiguard
2+
3+
import datadog.communication.serialization.EncodingCache
4+
import datadog.communication.serialization.GrowableBuffer
5+
import datadog.communication.serialization.msgpack.MsgPackWriter
6+
import datadog.trace.api.aiguard.AIGuard
7+
import datadog.trace.test.util.DDSpecification
8+
import org.msgpack.core.MessagePack
9+
import org.msgpack.value.Value
10+
11+
import java.nio.charset.StandardCharsets
12+
import java.util.function.Function
13+
14+
class MessageWriterTest extends DDSpecification {
15+
16+
private EncodingCache encodingCache
17+
private GrowableBuffer buffer
18+
private MsgPackWriter writer
19+
20+
void setup() {
21+
injectSysConfig('ai_guard.enabled', 'true')
22+
final HashMap<CharSequence, byte[]> cache = new HashMap<>()
23+
encodingCache = new EncodingCache() {
24+
@Override
25+
byte[] encode(CharSequence chars) {
26+
cache.computeIfAbsent(chars, s -> s.toString().getBytes(StandardCharsets.UTF_8))
27+
}
28+
}
29+
buffer = new GrowableBuffer(1024)
30+
writer = new MsgPackWriter(buffer)
31+
}
32+
33+
void 'test write message'() {
34+
given:
35+
final message = AIGuard.Message.message('user', 'What day is today?')
36+
37+
when:
38+
writer.writeObject(message, encodingCache)
39+
40+
then:
41+
try (final unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) {
42+
final value = asStringValueMap(unpacker.unpackValue())
43+
value.size() == 2
44+
value.role == 'user'
45+
value.content == 'What day is today?'
46+
}
47+
}
48+
49+
void 'test write tool call'() {
50+
given:
51+
final message =
52+
AIGuard.Message.assistant(
53+
AIGuard.ToolCall.toolCall('call_1', 'function_1', 'args_1'),
54+
AIGuard.ToolCall.toolCall('call_2', 'function_2', 'args_2'))
55+
56+
when:
57+
writer.writeObject(message, encodingCache)
58+
59+
then:
60+
try (final unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) {
61+
final value = asStringKeyMap(unpacker.unpackValue())
62+
value.size() == 2
63+
asString(value.role) == 'assistant'
64+
65+
final toolCalls = value.get('tool_calls').asArrayValue().list()
66+
toolCalls.size() == 2
67+
68+
final firstCall = asStringKeyMap(toolCalls[0])
69+
asString(firstCall.id) == 'call_1'
70+
final firstFunction = asStringValueMap(firstCall.function)
71+
firstFunction.name == 'function_1'
72+
firstFunction.arguments == 'args_1'
73+
74+
final secondCall = asStringKeyMap(toolCalls[1])
75+
asString(secondCall.id) == 'call_2'
76+
final secondFunction = asStringValueMap(secondCall.function)
77+
secondFunction.name == 'function_2'
78+
secondFunction.arguments == 'args_2'
79+
}
80+
}
81+
82+
void 'test write tool output'() throws IOException {
83+
given:
84+
final message = AIGuard.Message.tool('call_1', 'output')
85+
86+
when:
87+
writer.writeObject(message, encodingCache)
88+
89+
then:
90+
try (final unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) {
91+
final value = asStringValueMap(unpacker.unpackValue())
92+
value.size() == 3
93+
value.role == 'tool'
94+
value.tool_call_id == 'call_1'
95+
value.content == 'output'
96+
}
97+
}
98+
99+
private static <K, V> Map<K, V> mapValue(
100+
final Value values,
101+
final Function<Value, K> keyMapper,
102+
final Function<Value, V> valueMapper) {
103+
return values.asMapValue().entrySet().collectEntries {
104+
[(keyMapper.apply(it.key)): valueMapper.apply(it.value)]
105+
}
106+
}
107+
108+
private static Map<String, Value> asStringKeyMap(final Value values) {
109+
return mapValue(values, MessageWriterTest::asString, Function.identity())
110+
}
111+
112+
private static Map<String, String> asStringValueMap(final Value values) {
113+
return mapValue(values, MessageWriterTest::asString, MessageWriterTest::asString)
114+
}
115+
116+
private static String asString(final Value value) {
117+
return value.asStringValue().asString()
118+
}
119+
}

communication/src/test/groovy/datadog/communication/serialization/aiguard/MessageWriterTest.java

Lines changed: 0 additions & 121 deletions
This file was deleted.

0 commit comments

Comments
 (0)