Skip to content

Commit e39788a

Browse files
authored
feat(templates): refactor vector embedding in bit Boilerplate #11448 (#11450)
1 parent 7e32552 commit e39788a

File tree

13 files changed

+115
-69
lines changed

13 files changed

+115
-69
lines changed

src/Templates/Boilerplate/Bit.Boilerplate/.template.config/template.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,7 @@
526526
]
527527
},
528528
{
529-
"condition": "(signalR != true && database != PostgreSQL && database != SqlServer)",
529+
"condition": "(database != PostgreSQL && database != SqlServer)",
530530
"exclude": [
531531
"src/Server/Boilerplate.Server.Api/Services/ProductEmbeddingService.cs"
532532
]

src/Templates/Boilerplate/Bit.Boilerplate/src/Client/Boilerplate.Client.Core/Components/Pages/Products/ProductsPage.razor.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ private void PrepareGridDataProvider()
8080
var queriedRequest = productController.WithQuery(query.ToString());
8181
var data = await (string.IsNullOrWhiteSpace(searchQuery)
8282
? queriedRequest.GetProducts(req.CancellationToken)
83-
: queriedRequest.GetProductsBySearchQuery(searchQuery, req.CancellationToken));
83+
: queriedRequest.SearchProducts(searchQuery, req.CancellationToken));
8484

8585
return BitDataGridItemsProviderResult.From(data!.Items!, (int)data!.TotalCount);
8686
}

src/Templates/Boilerplate/Bit.Boilerplate/src/Directory.Packages.props

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@
8181
<PackageVersion Condition=" ('$(signalR)' == 'true' OR '$(signalR)' == '') OR ('$(database)' == 'PostgreSQL' OR '$(database)' == '') OR ('$(database)' == 'SqlServer' OR '$(database)' == '') " Include="Microsoft.Extensions.AI" Version="9.9.0" />
8282
<PackageVersion Condition=" ('$(signalR)' == 'true' OR '$(signalR)' == '') OR ('$(database)' == 'PostgreSQL' OR '$(database)' == '') OR ('$(database)' == 'SqlServer' OR '$(database)' == '') " Include="Microsoft.Extensions.AI.OpenAI" Version="9.9.0-preview.1.25458.4" />
8383
<PackageVersion Condition=" ('$(signalR)' == 'true' OR '$(signalR)' == '') OR ('$(database)' == 'PostgreSQL' OR '$(database)' == '') OR ('$(database)' == 'SqlServer' OR '$(database)' == '') " Include="Microsoft.Extensions.AI.AzureAIInference" Version="9.9.0-preview.1.25458.4" />
84+
<PackageVersion Condition=" ('$(signalR)' == 'true' OR '$(signalR)' == '') OR ('$(database)' == 'PostgreSQL' OR '$(database)' == '') OR ('$(database)' == 'SqlServer' OR '$(database)' == '') " Include="SmartComponents.LocalEmbeddings.SemanticKernel" Version="0.1.0-preview10148" />
85+
<PackageVersion Condition=" ('$(signalR)' == 'true' OR '$(signalR)' == '') OR ('$(database)' == 'PostgreSQL' OR '$(database)' == '') OR ('$(database)' == 'SqlServer' OR '$(database)' == '') " Include="Microsoft.SemanticKernel.Core" Version="1.65.0" />
8486
<PackageVersion Condition=" ('$(database)' == 'PostgreSQL' OR '$(database)' == '') " Include="Pgvector.EntityFrameworkCore" Version="0.2.2" />
8587
<PackageVersion Condition="'$(module)' == 'Admin' OR '$(module)' == ''" Include="Newtonsoft.Json" Version="13.0.4" />
8688
<PackageVersion Condition=" '$(appInsights)' == 'true' OR '$(appInsights)' == '' " Include="Microsoft.Extensions.Logging.ApplicationInsights" Version="2.23.0" />

src/Templates/Boilerplate/Bit.Boilerplate/src/Server/Boilerplate.Server.Api/Boilerplate.Server.Api.csproj

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,9 @@
6363
<PackageReference Condition=" ('$(signalR)' == 'true' OR '$(signalR)' == '') OR ('$(database)' == 'PostgreSQL' OR '$(database)' == '') OR ('$(database)' == 'SqlServer' OR '$(database)' == '') " Include="Microsoft.Extensions.AI" />
6464
<PackageReference Condition=" ('$(signalR)' == 'true' OR '$(signalR)' == '') OR ('$(database)' == 'PostgreSQL' OR '$(database)' == '') OR ('$(database)' == 'SqlServer' OR '$(database)' == '') " Include="Microsoft.Extensions.AI.AzureAIInference" />
6565
<PackageReference Condition=" ('$(signalR)' == 'true' OR '$(signalR)' == '') OR ('$(database)' == 'PostgreSQL' OR '$(database)' == '') OR ('$(database)' == 'SqlServer' OR '$(database)' == '') " Include="Microsoft.Extensions.AI.OpenAI" />
66+
<PackageReference Condition=" ('$(signalR)' == 'true' OR '$(signalR)' == '') OR ('$(database)' == 'PostgreSQL' OR '$(database)' == '') OR ('$(database)' == 'SqlServer' OR '$(database)' == '') " Include="Microsoft.SemanticKernel.Core" />
67+
<PackageReference Condition=" ('$(signalR)' == 'true' OR '$(signalR)' == '') OR ('$(database)' == 'PostgreSQL' OR '$(database)' == '') OR ('$(database)' == 'SqlServer' OR '$(database)' == '') " Include="SmartComponents.LocalEmbeddings.SemanticKernel" />
6668
<PackageReference Condition=" ('$(database)' == 'PostgreSQL' OR '$(database)' == '') " Include="Pgvector.EntityFrameworkCore" />
67-
6869
<Using Include="Microsoft.EntityFrameworkCore.Migrations" />
6970
<Using Include="Microsoft.EntityFrameworkCore.Metadata.Builders" />
7071
<Using Include="Microsoft.AspNetCore.Identity.EntityFrameworkCore" />

src/Templates/Boilerplate/Bit.Boilerplate/src/Server/Boilerplate.Server.Api/Controllers/Products/ProductController.cs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ public partial class ProductController : AppControllerBase, IProductController
2121
//#if (signalR == true)
2222
[AutoInject] private IHubContext<AppHub> appHubContext = default!;
2323
//#endif
24-
//#if (signalR == true || database == "PostgreSQL" || database == "SqlServer")
24+
//#if (database == "PostgreSQL" || database == "SqlServer")
2525
[AutoInject] private ProductEmbeddingService productEmbeddingService = default!;
2626
//#endif
2727
[AutoInject] private ResponseCacheService responseCacheService = default!;
@@ -47,10 +47,10 @@ public async Task<PagedResult<ProductDto>> GetProducts(ODataQueryOptions<Product
4747
}
4848

4949
[HttpGet("{searchQuery}")]
50-
public async Task<PagedResult<ProductDto>> GetProductsBySearchQuery(string searchQuery, ODataQueryOptions<ProductDto> odataQuery, CancellationToken cancellationToken)
50+
public async Task<PagedResult<ProductDto>> SearchProducts(string searchQuery, ODataQueryOptions<ProductDto> odataQuery, CancellationToken cancellationToken)
5151
{
5252
//#if (database == "PostgreSQL" || database == "SqlServer")
53-
var query = (IQueryable<ProductDto>)odataQuery.ApplyTo((await (productEmbeddingService.GetProductsBySearchQuery(searchQuery, cancellationToken))).Project(),
53+
var query = (IQueryable<ProductDto>)odataQuery.ApplyTo((await (productEmbeddingService.SearchProducts(searchQuery, cancellationToken))).Project(),
5454
ignoreQueryOptions: AllowedQueryOptions.Top | AllowedQueryOptions.Skip | AllowedQueryOptions.OrderBy /* Ordering can disrupt the results of the embedding service. */);
5555
var totalCount = await query.LongCountAsync(cancellationToken);
5656

@@ -59,9 +59,7 @@ public async Task<PagedResult<ProductDto>> GetProductsBySearchQuery(string searc
5959

6060
return new PagedResult<ProductDto>(await query.ToArrayAsync(cancellationToken), totalCount);
6161
//#else
62-
// Embedding based search is only implemented for PostgreSQL.
63-
// Simply return whole products list.
64-
return await GetProducts(odataQuery, cancellationToken);
62+
throw new NotImplementedException(); // Embedding based search is only implemented for PostgreSQL and SQL Server only.
6563
//#endif
6664
}
6765

src/Templates/Boilerplate/Bit.Boilerplate/src/Server/Boilerplate.Server.Api/Data/Configurations/Chatbot/SystemPromptConfiguration.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,8 @@ These are the primary functional areas of the application beyond account managem
174174
- If the user asks multiple questions, list them back to the user to confirm understanding, then address each one separately with clear headings. If needed, ask them to prioritize: ""I see you have multiple questions. Which issue would you like me to address first?""
175175
176176
- Never request sensitive information (e.g., passwords, PINs). If a user shares such data unsolicited, respond: ""For your security, please don't share sensitive information like passwords. Rest assured, your data is safe with us."" " +
177-
//#if (module == 'Sales')
177+
//#if (module == "Sales")
178+
//#if (database == "PostgreSQL" || database == "SqlServer")
178179
@"### Handling Car Recommendation Requests:
179180
**[[[CAR_RECOMMENDATION_RULES_BEGIN]]]**
180181
* **If a user asks for help choosing a car, for recommendations, or expresses purchase intent (e.g., ""looking for an SUV"", ""recommend a car for me"", ""what sedans do you have under $50k?""):**
@@ -195,6 +196,7 @@ These are the primary functional areas of the application beyond account managem
195196
**[[[CAR_RECOMMENDATION_RULES_END]]]**
196197
" +
197198
//#endif
199+
//#endif
198200
//#if (ads == true)
199201
@"### Handling advertisement trouble requests:
200202
**[[[ADS_TROUBLE_RULES_BEGIN]]]""

src/Templates/Boilerplate/Bit.Boilerplate/src/Server/Boilerplate.Server.Api/Data/Configurations/Product/ProductConfiguration.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,12 @@ public void Configure(EntityTypeBuilder<Product> builder)
1919
//#if (database == "PostgreSQL" || database == "SqlServer")
2020
if (AppDbContext.IsEmbeddingEnabled)
2121
{
22-
builder.Property(p => p.Embedding).HasColumnType("vector(1536)"); // 1536 for text-embedding-3-small
22+
builder.Property(p => p.Embedding).HasColumnType("vector(384)"); // Checkout appsettings.json's AI:EmbeddingOptions:Dimensions
23+
//#if (database == "PostgreSQL")
24+
builder.HasIndex(m => m.Embedding)
25+
.HasMethod("hnsw") // ivfflat
26+
.HasOperators("vector_cosine_ops");
27+
//#endif
2328
}
2429
else
2530
{

src/Templates/Boilerplate/Bit.Boilerplate/src/Server/Boilerplate.Server.Api/Program.Services.cs

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
using System.IO.Compression;
55
//#if (signalR == true || database == "PostgreSQL" || database == "SqlServer")
66
using System.ClientModel.Primitives;
7+
using Microsoft.SemanticKernel.Embeddings;
8+
using SmartComponents.LocalEmbeddings.SemanticKernel;
79
//#endif
810
//#if (database == "Sqlite")
911
using Microsoft.Data.Sqlite;
@@ -67,7 +69,7 @@ public static void AddServerApiProjectServices(this WebApplicationBuilder builde
6769
services.AddScoped<PhoneService>();
6870
services.AddScoped<PhoneServiceJobsRunner>();
6971
//#if (module == "Sales" || module == "Admin")
70-
//#if (signalR == true || database == "PostgreSQL" || database == "SqlServer")
72+
//#if (database == "PostgreSQL" || database == "SqlServer")
7173
services.AddScoped<ProductEmbeddingService>();
7274
//#endif
7375
//#endif
@@ -434,6 +436,10 @@ void AddDbContext(DbContextOptionsBuilder options)
434436
Endpoint = appSettings.AI.OpenAI.EmbeddingEndpoint,
435437
Transport = new HttpClientPipelineTransport(sp.GetRequiredService<IHttpClientFactory>().CreateClient("AI"))
436438
}).AsIEmbeddingGenerator())
439+
.ConfigureOptions(options =>
440+
{
441+
configuration.GetRequiredSection("AI:EmbeddingOptions").Bind(options);
442+
})
437443
.UseLogging()
438444
.UseOpenTelemetry();
439445
// .UseDistributedCache()
@@ -446,10 +452,26 @@ void AddDbContext(DbContextOptionsBuilder options)
446452
{
447453
Transport = new Azure.Core.Pipeline.HttpClientTransport(sp.GetRequiredService<IHttpClientFactory>().CreateClient("AI"))
448454
}).AsIEmbeddingGenerator(appSettings.AI.AzureOpenAI.EmbeddingModel))
455+
.ConfigureOptions(options =>
456+
{
457+
configuration.GetRequiredSection("AI:EmbeddingOptions").Bind(options);
458+
})
449459
.UseLogging()
450460
.UseOpenTelemetry();
451461
// .UseDistributedCache()
452462
}
463+
else
464+
{
465+
services.AddEmbeddingGenerator(sp => new LocalTextEmbeddingGenerationService()
466+
.AsEmbeddingGenerator())
467+
.ConfigureOptions(options =>
468+
{
469+
configuration.GetRequiredSection("AI:EmbeddingOptions").Bind(options);
470+
})
471+
.UseLogging()
472+
.UseOpenTelemetry();
473+
// .UseDistributedCache()
474+
}
453475
//#endif
454476

455477
builder.Services.AddHangfire(configuration =>

src/Templates/Boilerplate/Bit.Boilerplate/src/Server/Boilerplate.Server.Api/Services/ProductEmbeddingService.cs

Lines changed: 63 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -11,98 +11,106 @@ namespace Boilerplate.Server.Api.Services;
1111
/// 1- Simple string matching (e.g., `Contains` method).
1212
/// 2- Full-text search using database capabilities (e.g., PostgreSQL's full-text search).
1313
/// 3- Vector-based search using embeddings (e.g., using OpenAI's embeddings).
14-
/// This service implements vector-based search using embeddings that has the following advantages:
15-
/// - More accurate search results based on semantic meaning rather than just similarity matching.
16-
/// - Multi-language support, as embeddings can capture the meaning of words across different languages.
17-
/// And has the following disadvantages:
18-
/// - Requires additional processing to generate embeddings for the text.
19-
/// - Require more storage space for embeddings compared to simple text search.
20-
/// The simple full-text search would be enough for product search case, but we have implemented the vector-based search to demonstrate how to use embeddings in the project.
14+
/// 4- Hybrid approach combining full-text search and vector-based search.
15+
/// The vector-based search is overkill for products search, but we implemented it here so you can see how to implement it in case you need it for other scenarios.
2116
/// </summary>
2217
public partial class ProductEmbeddingService
2318
{
24-
private const float SIMILARITY_THRESHOLD = 0.85f;
19+
private const float DISTANCE_THRESHOLD = 0.65f;
2520

2621
[AutoInject] private AppDbContext dbContext = default!;
27-
[AutoInject] private IWebHostEnvironment env = default!;
28-
[AutoInject] private IServiceProvider serviceProvider = default!;
22+
[AutoInject] private IEmbeddingGenerator<string, Embedding<float>> embeddingGenerator = default!;
2923

30-
public async Task<IQueryable<Product>> GetProductsBySearchQuery(string searchQuery, CancellationToken cancellationToken)
24+
public async Task<IQueryable<Product>> SearchProducts(string searchQuery, CancellationToken cancellationToken)
3125
{
32-
//#if (database != "PostgreSQL" && database != "SqlServer")
33-
// The RAG has been implemented for PostgreSQL / SQL Server only. Check out https://github.com/bitfoundation/bitplatform/blob/develop/src/Templates/Boilerplate/Bit.Boilerplate/src/Server/Boilerplate.Server.Api/Services/ProductEmbeddingService.cs
34-
return dbContext.Products.Where(p => p.Name!.Contains(searchQuery) || p.Category!.Name!.Contains(searchQuery));
35-
//#else
36-
var embeddedUserQuery = await EmbedText(searchQuery, cancellationToken);
37-
if (embeddedUserQuery is null)
38-
return dbContext.Products.Where(p => p.Name!.Contains(searchQuery) || p.Category!.Name!.Contains(searchQuery));
26+
if (AppDbContext.IsEmbeddingEnabled is false)
27+
throw new InvalidOperationException("Embeddings are not enabled. Please enable them to use this feature.");
28+
29+
// It would be a good idea to try finding products using full-text search first, and if not enough results are found, then use the vector-based search.
30+
// Note that test products data that have been seeded do not have embeddings, so searching for them will not return any results.
31+
32+
var embeddedSearchQuery = await embeddingGenerator.GenerateAsync(searchQuery, cancellationToken: cancellationToken);
33+
3934
//#if (database == "PostgreSQL")
40-
var value = new Pgvector.Vector(embeddedUserQuery.Value);
35+
var value = new Pgvector.Vector(embeddedSearchQuery.Vector);
4136
//#else
4237
//#if (IsInsideProjectTemplate == true)
4338
/*
4439
//#endif
45-
var value = new Microsoft.Data.SqlTypes.SqlVector<float>(embeddedUserQuery.Value);
40+
var value = new Microsoft.Data.SqlTypes.SqlVector<float>(embeddedSearchQuery.Vector);
4641
//#if (IsInsideProjectTemplate == true)
4742
*/
4843
//#endif
4944
//#endif
5045
return dbContext.Products
5146
//#if (database == "PostgreSQL")
52-
.Where(p => p.Embedding!.CosineDistance(value!) < SIMILARITY_THRESHOLD).OrderBy(p => p.Embedding!.CosineDistance(value!));
47+
.Where(p => p.Embedding!.CosineDistance(value!) < DISTANCE_THRESHOLD).OrderBy(p => p.Embedding!.CosineDistance(value!));
5348
//#elif (database == "SqlServer")
5449
//#if (IsInsideProjectTemplate == true)
5550
/*
5651
//#endif
57-
.Where(p => p.Embedding.HasValue && EF.Functions.VectorDistance("cosine", p.Embedding.Value, value) < SIMILARITY_THRESHOLD).OrderBy(p => EF.Functions.VectorDistance("cosine", p.Embedding!.Value, value!));
52+
.Where(p => p.Embedding.HasValue && EF.Functions.VectorDistance("cosine", p.Embedding.Value, value) < DISTANCE_THRESHOLD).OrderBy(p => EF.Functions.VectorDistance("cosine", p.Embedding!.Value, value!));
5853
//#if (IsInsideProjectTemplate == true)
5954
*/
6055
//#endif
6156
//#endif
62-
//#endif
6357
}
6458

6559
public async Task Embed(Product product, CancellationToken cancellationToken)
6660
{
67-
//#if (database != "PostgreSQL" && database != "SqlServer")
68-
return; // The RAG has been implemented for PostgreSQL / SQL Server only.
69-
//#else
70-
await dbContext.Entry(product).Reference(p => p.Category).LoadAsync(cancellationToken);
61+
if (AppDbContext.IsEmbeddingEnabled is false)
62+
throw new InvalidOperationException("Embeddings are not enabled. Please enable them to use this feature.");
63+
64+
List<(string text, float weight)> inputs = [];
7165

72-
// TODO: Needs to be improved.
73-
var embedding = await EmbedText($@"
74-
Name: **{product.Name}**
75-
Manufacture: **{product.Category!.Name}**
76-
Description: {product.DescriptionText}
77-
Appearance: {product.PrimaryImageAltText}", cancellationToken);
66+
await dbContext.Entry(product)
67+
.Reference(p => p.Category)
68+
.LoadAsync(cancellationToken);
7869

79-
if (embedding.HasValue)
70+
inputs.Add(($"Id: {product.ShortId}", 0.9f));
71+
inputs.Add(($"Name: {product.Name}", 0.9f));
72+
if (string.IsNullOrEmpty(product.DescriptionText) is false)
8073
{
81-
product.Embedding = new(embedding.Value);
74+
inputs.Add((product.DescriptionText, 0.7f));
8275
}
83-
//#endif
84-
}
76+
if (string.IsNullOrEmpty(product.PrimaryImageAltText) is false)
77+
{
78+
inputs.Add((product.PrimaryImageAltText, 0.5f));
79+
}
80+
inputs.Add((product.Category!.Name!, 0.9f));
8581

86-
private async Task<ReadOnlyMemory<float>?> EmbedText(string input, CancellationToken cancellationToken)
87-
{
88-
//#if (database != "PostgreSQL" && database != "SqlServer")
89-
return null; // The RAG has been implemented for PostgreSQL / SQL Server only.
90-
//#else
91-
if (AppDbContext.IsEmbeddingEnabled is false)
92-
return null;
93-
var embeddingGenerator = serviceProvider.GetService<IEmbeddingGenerator<string, Embedding<float>>>();
94-
if (embeddingGenerator is null)
95-
return env.IsDevelopment() ? null : throw new InvalidOperationException("Embedding generator is not registered.");
82+
var texts = inputs.Select(i => i.text).ToArray();
9683

97-
input = $@"
98-
Name: **{input}**
99-
Manufacture: **{input}**
100-
Description: {input}
101-
Appearance: {input}";
84+
var embeddingsResponse = await embeddingGenerator.GenerateAsync(texts, cancellationToken: cancellationToken);
10285

86+
var vectors = embeddingsResponse.Select(e => e.Vector.ToArray()).ToArray();
87+
var weights = inputs.Select(t => t.weight).ToArray();
10388

104-
var embedding = await embeddingGenerator.GenerateVectorAsync(input, options: new() { }, cancellationToken);
105-
return embedding.ToArray();
106-
//#endif
89+
if (vectors.Any(v => v.Length != vectors[0].Length))
90+
{
91+
throw new InvalidOperationException("All embedding vectors must have the same length.");
92+
}
93+
94+
var embedding = new float[vectors[0].Length];
95+
for (int i = 0; i < embedding.Length; i++)
96+
{
97+
embedding[i] = 0f;
98+
for (int j = 0; j < vectors.Length; j++)
99+
{
100+
embedding[i] += weights[j] * vectors[j][i];
101+
}
102+
}
103+
104+
// L2 normalize the embedding for cosine distance stability
105+
float norm = (float)Math.Sqrt(embedding.Sum(v => v * v));
106+
if (norm > 0)
107+
{
108+
for (int i = 0; i < embedding.Length; i++)
109+
{
110+
embedding[i] /= norm;
111+
}
112+
}
113+
114+
product.Embedding = new(embedding);
107115
}
108116
}

0 commit comments

Comments
 (0)