Skip to content

Commit 2c40207

Browse files
committed
feat(chat): Doctrine Dbal message store
1 parent a9b98a1 commit 2c40207

File tree

8 files changed

+484
-0
lines changed

8 files changed

+484
-0
lines changed
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
<?php
2+
3+
/*
4+
* This file is part of the Symfony package.
5+
*
6+
* (c) Fabien Potencier <[email protected]>
7+
*
8+
* For the full copyright and license information, please view the LICENSE
9+
* file that was distributed with this source code.
10+
*/
11+
12+
use Doctrine\DBAL\DriverManager;
13+
use Doctrine\DBAL\Tools\DsnParser;
14+
use Symfony\AI\Agent\Agent;
15+
use Symfony\AI\Chat\Bridge\Doctrine\DoctrineDbalMessageStore;
16+
use Symfony\AI\Chat\Chat;
17+
use Symfony\AI\Platform\Bridge\OpenAi\PlatformFactory;
18+
use Symfony\AI\Platform\Message\Message;
19+
use Symfony\AI\Platform\Message\MessageBag;
20+
21+
require_once dirname(__DIR__).'/bootstrap.php';
22+
23+
$platform = PlatformFactory::create(env('OPENAI_API_KEY'), http_client());
24+
25+
$connection = DriverManager::getConnection((new DsnParser())->parse('pdo-sqlite:///:memory:'));
26+
27+
$store = new DoctrineDbalMessageStore('symfony', $connection);
28+
$store->setup();
29+
30+
$agent = new Agent($platform, 'gpt-4o-mini');
31+
$chat = new Chat($agent, $store);
32+
33+
$messages = new MessageBag(
34+
Message::forSystem('You are a helpful assistant. You only answer with short sentences.'),
35+
);
36+
37+
$chat->initiate($messages);
38+
$chat->submit(Message::ofUser('My name is Christopher.'));
39+
$message = $chat->submit(Message::ofUser('What is my name?'));
40+
41+
echo $message->getContent().\PHP_EOL;

examples/commands/message-stores.php

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111

1212
require_once dirname(__DIR__).'/bootstrap.php';
1313

14+
use Doctrine\DBAL\DriverManager;
15+
use Doctrine\DBAL\Tools\DsnParser;
16+
use Symfony\AI\Chat\Bridge\Doctrine\DoctrineDbalMessageStore;
1417
use Symfony\AI\Chat\Bridge\HttpFoundation\SessionStore;
1518
use Symfony\AI\Chat\Bridge\Local\CacheStore;
1619
use Symfony\AI\Chat\Bridge\Local\InMemoryStore;
@@ -36,6 +39,10 @@
3639

3740
$factories = [
3841
'cache' => static fn (): CacheStore => new CacheStore(new ArrayAdapter(), cacheKey: 'symfony'),
42+
'doctrine' => static fn (): DoctrineDbalMessageStore => new DoctrineDbalMessageStore(
43+
'symfony',
44+
DriverManager::getConnection((new DsnParser())->parse('pdo-sqlite:///:memory:')),
45+
),
3946
'meilisearch' => static fn (): MeilisearchMessageStore => new MeilisearchMessageStore(
4047
http_client(),
4148
env('MEILISEARCH_HOST'),

src/ai-bundle/config/options.php

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -786,6 +786,21 @@
786786
->end()
787787
->end()
788788
->end()
789+
->arrayNode('doctrine')
790+
->children()
791+
->arrayNode('dbal')
792+
->useAttributeAsKey('name')
793+
->arrayPrototype()
794+
->children()
795+
->stringNode('connection')->cannotBeEmpty()->end()
796+
->stringNode('table_name')
797+
->info('The name of the message store will be used if the table_name is not set')
798+
->end()
799+
->end()
800+
->end()
801+
->end()
802+
->end()
803+
->end()
789804
->arrayNode('meilisearch')
790805
->useAttributeAsKey('name')
791806
->arrayPrototype()

src/ai-bundle/src/AiBundle.php

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
use Symfony\AI\AiBundle\Profiler\TraceablePlatform;
3737
use Symfony\AI\AiBundle\Profiler\TraceableToolbox;
3838
use Symfony\AI\AiBundle\Security\Attribute\IsGrantedTool;
39+
use Symfony\AI\Chat\Bridge\Doctrine\DoctrineDbalMessageStore;
3940
use Symfony\AI\Chat\Bridge\HttpFoundation\SessionStore;
4041
use Symfony\AI\Chat\Bridge\Local\CacheStore as CacheMessageStore;
4142
use Symfony\AI\Chat\Bridge\Meilisearch\MessageStore as MeilisearchMessageStore;
@@ -1495,6 +1496,26 @@ private function processMessageStoreConfig(string $type, array $messageStores, C
14951496
}
14961497
}
14971498

1499+
if ('doctrine' === $type) {
1500+
foreach ($messageStores['dbal'] ?? [] as $name => $dbalMessageStore) {
1501+
$definition = new Definition(DoctrineDbalMessageStore::class);
1502+
$definition
1503+
->setLazy(true)
1504+
->setArguments([
1505+
$dbalMessageStore['connection'],
1506+
$dbalMessageStore['table_name'] ?? $name,
1507+
new Reference(\sprintf('doctrine.dbal.%s_connection', $dbalMessageStore['connection'])),
1508+
new Reference('serializer'),
1509+
])
1510+
->addTag('proxy', ['interface' => MessageStoreInterface::class])
1511+
->addTag('ai.message_store');
1512+
1513+
$container->setDefinition('ai.message_store.'.$type.'.dbal.'.$name, $definition);
1514+
$container->registerAliasForArgument('ai.message_store.'.$type.'.'.$name, MessageStoreInterface::class, $name);
1515+
$container->registerAliasForArgument('ai.message_store.'.$type.'.'.$name, MessageStoreInterface::class, $type.'_'.$name);
1516+
}
1517+
}
1518+
14981519
if ('meilisearch' === $type) {
14991520
foreach ($messageStores as $name => $messageStore) {
15001521
$definition = new Definition(MeilisearchMessageStore::class);

src/ai-bundle/tests/DependencyInjection/AiBundleTest.php

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2851,6 +2851,67 @@ public function testCacheMessageStoreCanBeConfiguredWithCustomTtl()
28512851
$this->assertTrue($cacheMessageStoreDefinition->hasTag('ai.message_store'));
28522852
}
28532853

2854+
public function testDoctrineDbalMessageStoreCanBeConfiguredWithCustomKey()
2855+
{
2856+
$container = $this->buildContainer([
2857+
'ai' => [
2858+
'message_store' => [
2859+
'doctrine' => [
2860+
'dbal' => [
2861+
'default' => [
2862+
'connection' => 'default',
2863+
],
2864+
],
2865+
],
2866+
],
2867+
],
2868+
]);
2869+
2870+
$doctrineDbalDefaultMessageStoreDefinition = $container->getDefinition('ai.message_store.doctrine.dbal.default');
2871+
2872+
$this->assertSame('default', (string) $doctrineDbalDefaultMessageStoreDefinition->getArgument(0));
2873+
$this->assertSame('default', (string) $doctrineDbalDefaultMessageStoreDefinition->getArgument(1));
2874+
$this->assertInstanceOf(Reference::class, $doctrineDbalDefaultMessageStoreDefinition->getArgument(2));
2875+
$this->assertSame('doctrine.dbal.default_connection', (string) $doctrineDbalDefaultMessageStoreDefinition->getArgument(2));
2876+
$this->assertInstanceOf(Reference::class, $doctrineDbalDefaultMessageStoreDefinition->getArgument(3));
2877+
$this->assertSame('serializer', (string) $doctrineDbalDefaultMessageStoreDefinition->getArgument(3));
2878+
2879+
$this->assertTrue($doctrineDbalDefaultMessageStoreDefinition->hasTag('proxy'));
2880+
$this->assertSame([['interface' => MessageStoreInterface::class]], $doctrineDbalDefaultMessageStoreDefinition->getTag('proxy'));
2881+
$this->assertTrue($doctrineDbalDefaultMessageStoreDefinition->hasTag('ai.message_store'));
2882+
}
2883+
2884+
public function testDoctrineDbalMessageStoreWithCustomTableNameCanBeConfiguredWithCustomKey()
2885+
{
2886+
$container = $this->buildContainer([
2887+
'ai' => [
2888+
'message_store' => [
2889+
'doctrine' => [
2890+
'dbal' => [
2891+
'default' => [
2892+
'connection' => 'default',
2893+
'table_name' => 'foo',
2894+
],
2895+
],
2896+
],
2897+
],
2898+
],
2899+
]);
2900+
2901+
$doctrineDbalDefaultMessageStoreDefinition = $container->getDefinition('ai.message_store.doctrine.dbal.default');
2902+
2903+
$this->assertSame('default', (string) $doctrineDbalDefaultMessageStoreDefinition->getArgument(0));
2904+
$this->assertSame('foo', (string) $doctrineDbalDefaultMessageStoreDefinition->getArgument(1));
2905+
$this->assertInstanceOf(Reference::class, $doctrineDbalDefaultMessageStoreDefinition->getArgument(2));
2906+
$this->assertSame('doctrine.dbal.default_connection', (string) $doctrineDbalDefaultMessageStoreDefinition->getArgument(2));
2907+
$this->assertInstanceOf(Reference::class, $doctrineDbalDefaultMessageStoreDefinition->getArgument(3));
2908+
$this->assertSame('serializer', (string) $doctrineDbalDefaultMessageStoreDefinition->getArgument(3));
2909+
2910+
$this->assertTrue($doctrineDbalDefaultMessageStoreDefinition->hasTag('proxy'));
2911+
$this->assertSame([['interface' => MessageStoreInterface::class]], $doctrineDbalDefaultMessageStoreDefinition->getTag('proxy'));
2912+
$this->assertTrue($doctrineDbalDefaultMessageStoreDefinition->hasTag('ai.message_store'));
2913+
}
2914+
28542915
public function testMeilisearchMessageStoreIsConfigured()
28552916
{
28562917
$container = $this->buildContainer([
@@ -3350,6 +3411,14 @@ private function getFullConfig(): array
33503411
'key' => 'foo',
33513412
],
33523413
],
3414+
'doctrine' => [
3415+
'dbal' => [
3416+
'default' => [
3417+
'connection' => 'default',
3418+
'table_name' => 'foo',
3419+
],
3420+
],
3421+
],
33533422
'memory' => [
33543423
'my_memory_message_store' => [
33553424
'identifier' => '_memory',

src/chat/composer.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
},
2727
"require-dev": {
2828
"ext-redis": "*",
29+
"doctrine/dbal": "^3.3 || ^4.0",
2930
"phpstan/phpstan": "^2.0",
3031
"phpstan/phpstan-strict-rules": "^2.0",
3132
"phpunit/phpunit": "^11.5.13",
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
<?php
2+
3+
/*
4+
* This file is part of the Symfony package.
5+
*
6+
* (c) Fabien Potencier <[email protected]>
7+
*
8+
* For the full copyright and license information, please view the LICENSE
9+
* file that was distributed with this source code.
10+
*/
11+
12+
namespace Symfony\AI\Chat\Bridge\Doctrine;
13+
14+
use Doctrine\DBAL\Connection;
15+
use Doctrine\DBAL\Connection as DBALConnection;
16+
use Doctrine\DBAL\Platforms\OraclePlatform;
17+
use Doctrine\DBAL\Result;
18+
use Doctrine\DBAL\Schema\Name\Identifier;
19+
use Doctrine\DBAL\Schema\Name\UnqualifiedName;
20+
use Doctrine\DBAL\Schema\PrimaryKeyConstraint;
21+
use Doctrine\DBAL\Schema\Schema;
22+
use Doctrine\DBAL\Types\Types;
23+
use Symfony\AI\Chat\Exception\InvalidArgumentException;
24+
use Symfony\AI\Chat\ManagedStoreInterface;
25+
use Symfony\AI\Chat\MessageNormalizer;
26+
use Symfony\AI\Chat\MessageStoreInterface;
27+
use Symfony\AI\Platform\Message\MessageBag;
28+
use Symfony\AI\Platform\Message\MessageInterface;
29+
use Symfony\Component\Serializer\Encoder\JsonEncoder;
30+
use Symfony\Component\Serializer\Normalizer\ArrayDenormalizer;
31+
use Symfony\Component\Serializer\Serializer;
32+
use Symfony\Component\Serializer\SerializerInterface;
33+
34+
/**
35+
* @author Guillaume Loulier <[email protected]>
36+
*/
37+
final class DoctrineDbalMessageStore implements ManagedStoreInterface, MessageStoreInterface
38+
{
39+
public function __construct(
40+
private readonly string $tableName,
41+
private readonly DBALConnection $dbalConnection,
42+
private readonly SerializerInterface $serializer = new Serializer([
43+
new ArrayDenormalizer(),
44+
new MessageNormalizer(),
45+
], [new JsonEncoder()]),
46+
) {
47+
}
48+
49+
public function setup(array $options = []): void
50+
{
51+
if ([] !== $options) {
52+
throw new InvalidArgumentException('No supported options.');
53+
}
54+
55+
$schema = $this->dbalConnection->createSchemaManager()->introspectSchema();
56+
57+
if ($schema->hasTable($this->tableName)) {
58+
return;
59+
}
60+
61+
$this->addTableToSchema($schema);
62+
}
63+
64+
public function drop(): void
65+
{
66+
$schema = $this->dbalConnection->createSchemaManager()->introspectSchema();
67+
68+
if (!$schema->hasTable($this->tableName)) {
69+
return;
70+
}
71+
72+
$queryBuilder = $this->dbalConnection->createQueryBuilder()
73+
->delete($this->tableName);
74+
75+
$this->dbalConnection->transactional(fn (Connection $connection): Result => $connection->executeQuery(
76+
$queryBuilder->getSQL(),
77+
));
78+
}
79+
80+
public function save(MessageBag $messages): void
81+
{
82+
$queryBuilder = $this->dbalConnection->createQueryBuilder()
83+
->insert($this->tableName)
84+
->values([
85+
'messages' => '?',
86+
]);
87+
88+
$this->dbalConnection->transactional(fn (Connection $connection): Result => $connection->executeQuery(
89+
$queryBuilder->getSQL(),
90+
[
91+
$this->serializer->serialize($messages->getMessages(), 'json'),
92+
],
93+
$queryBuilder->getParameterTypes(),
94+
));
95+
}
96+
97+
public function load(): MessageBag
98+
{
99+
$queryBuilder = $this->dbalConnection->createQueryBuilder()
100+
->select('messages')
101+
->from($this->tableName)
102+
;
103+
104+
$result = $this->dbalConnection->transactional(static fn (Connection $connection): Result => $connection->executeQuery(
105+
$queryBuilder->getSQL(),
106+
));
107+
108+
$messages = array_map(
109+
fn (array $payload): array => $this->serializer->deserialize($payload['messages'], MessageInterface::class.'[]', 'json'),
110+
$result->fetchAllAssociative(),
111+
);
112+
113+
return new MessageBag(...array_merge(...$messages));
114+
}
115+
116+
private function addTableToSchema(Schema $schema): void
117+
{
118+
$table = $schema->createTable($this->tableName);
119+
$table->addOption('_symfony_ai_chat_table_name', $this->tableName);
120+
$idColumn = $table->addColumn('id', Types::BIGINT)
121+
->setAutoincrement(true)
122+
->setNotnull(true);
123+
$table->addColumn('messages', Types::TEXT)
124+
->setNotnull(true);
125+
if (class_exists(PrimaryKeyConstraint::class)) {
126+
$table->addPrimaryKeyConstraint(new PrimaryKeyConstraint(null, [
127+
new UnqualifiedName(Identifier::unquoted('id')),
128+
], true));
129+
} else {
130+
$table->setPrimaryKey(['id']);
131+
}
132+
133+
// We need to create a sequence for Oracle and set the id column to get the correct nextval
134+
if ($this->dbalConnection->getDatabasePlatform() instanceof OraclePlatform) {
135+
$serverVersion = $this->dbalConnection->executeQuery("SELECT version FROM product_component_version WHERE product LIKE 'Oracle Database%'")->fetchOne();
136+
if (version_compare($serverVersion, '12.1.0', '>=')) {
137+
$idColumn->setAutoincrement(false); // disable the creation of SEQUENCE and TRIGGER
138+
$idColumn->setDefault($this->tableName.'_seq.nextval');
139+
140+
$schema->createSequence($this->tableName.'_seq');
141+
}
142+
}
143+
144+
foreach ($schema->toSql($this->dbalConnection->getDatabasePlatform()) as $sql) {
145+
$this->dbalConnection->executeQuery($sql);
146+
}
147+
}
148+
}

0 commit comments

Comments
 (0)