Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for embedding dimensions and fix prepdocs on Windows #342

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions app/backend/Services/ReadRetrieveReadChatService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,27 @@ public ReadRetrieveReadChatService(
kernelBuilder = kernelBuilder.AddOpenAIChatCompletion(deployment, client);

var embeddingModelName = configuration["OpenAiEmbeddingDeployment"];
int embeddingModelDimensions;
if (!int.TryParse(configuration["AzureOpenAiEmbeddingModelDimensions"], out embeddingModelDimensions)) {
embeddingModelDimensions = 1536;
}
ArgumentNullException.ThrowIfNullOrWhiteSpace(embeddingModelName);
kernelBuilder = kernelBuilder.AddOpenAITextEmbeddingGeneration(embeddingModelName, client);
kernelBuilder = kernelBuilder.AddOpenAITextEmbeddingGeneration(embeddingModelName, client, dimensions: embeddingModelDimensions);
}
else
{
var deployedModelName = configuration["AzureOpenAiChatGptDeployment"];
ArgumentNullException.ThrowIfNullOrWhiteSpace(deployedModelName);
var embeddingModelName = configuration["AzureOpenAiEmbeddingDeployment"];
int embeddingModelDimensions;
if (!int.TryParse(configuration["AzureOpenAiEmbeddingModelDimensions"], out embeddingModelDimensions)) {
embeddingModelDimensions = 1536;
}
if (!string.IsNullOrEmpty(embeddingModelName))
{
var endpoint = configuration["AzureOpenAiServiceEndpoint"];
ArgumentNullException.ThrowIfNullOrWhiteSpace(endpoint);
kernelBuilder = kernelBuilder.AddAzureOpenAITextEmbeddingGeneration(embeddingModelName, endpoint, tokenCredential ?? new DefaultAzureCredential());
kernelBuilder = kernelBuilder.AddAzureOpenAITextEmbeddingGeneration(embeddingModelName, endpoint, tokenCredential ?? new DefaultAzureCredential(), dimensions: embeddingModelDimensions);
kernelBuilder = kernelBuilder.AddAzureOpenAIChatCompletion(deployedModelName, endpoint, tokenCredential ?? new DefaultAzureCredential());
}
}
Expand Down
10 changes: 10 additions & 0 deletions app/functions/EmbedFunctions/Program.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.

using Azure.AI.OpenAI;
using Microsoft.Extensions.Logging.Abstractions;

var host = new HostBuilder()
.ConfigureServices(services =>
Expand Down Expand Up @@ -65,17 +66,24 @@ uri is not null

OpenAIClient? openAIClient = null;
string? embeddingModelName = null;

// Resolve the embedding vector size once, up front. Both the Azure OpenAI and the
// OpenAI code paths read the same AZURE_OPENAI_EMBEDDING_MODEL_DIMENSIONS variable,
// so parsing it in each branch only duplicated the logic.
// The value is machine-readable configuration, so it is parsed with the invariant
// culture. A missing, malformed, or non-positive value falls back to 1536.
if (!int.TryParse(
        Environment.GetEnvironmentVariable("AZURE_OPENAI_EMBEDDING_MODEL_DIMENSIONS"),
        System.Globalization.NumberStyles.Integer,
        System.Globalization.CultureInfo.InvariantCulture,
        out int embeddingModelDimensions)
    || embeddingModelDimensions <= 0)
{
    embeddingModelDimensions = 1536;
}

if (useAOAI)
{
    // Azure OpenAI: endpoint + deployment name, authenticated with DefaultAzureCredential.
    var openaiEndPoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new ArgumentNullException("AZURE_OPENAI_ENDPOINT is null");
    embeddingModelName = Environment.GetEnvironmentVariable("AZURE_OPENAI_EMBEDDING_DEPLOYMENT") ?? throw new ArgumentNullException("AZURE_OPENAI_EMBEDDING_DEPLOYMENT is null");
    openAIClient = new OpenAIClient(new Uri(openaiEndPoint), new DefaultAzureCredential());
}
else
{
    // OpenAI: deployment/model name + API key.
    embeddingModelName = Environment.GetEnvironmentVariable("OPENAI_EMBEDDING_DEPLOYMENT") ?? throw new ArgumentNullException("OPENAI_EMBEDDING_DEPLOYMENT is null");
    var openaiKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new ArgumentNullException("OPENAI_API_KEY is null");
    openAIClient = new OpenAIClient(openaiKey);
}

Expand All @@ -94,6 +102,7 @@ uri is not null
return new AzureSearchEmbedService(
openAIClient: openAIClient,
embeddingModelName: embeddingModelName,
embeddingModelDimensions: embeddingModelDimensions,
searchClient: searchClient,
searchIndexName: searchIndexName,
searchIndexClient: searchIndexClient,
Expand All @@ -108,6 +117,7 @@ uri is not null
return new AzureSearchEmbedService(
openAIClient: openAIClient,
embeddingModelName: embeddingModelName,
embeddingModelDimensions: embeddingModelDimensions,
searchClient: searchClient,
searchIndexName: searchIndexName,
searchIndexClient: searchIndexClient,
Expand Down
3 changes: 1 addition & 2 deletions app/map-env.ps1
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ foreach ($line in (& azd env get-values)) {
if ($line -match "([^=]+)=(.*)") {
$key = $matches[1]
$value = $matches[2] -replace '^"|"$'
[Environment]::SetEnvironmentVariable(
$key, $value, [System.EnvironmentVariableTarget]::User)
Set-Item "env:$key" $value
}
}

Expand Down
1 change: 1 addition & 0 deletions app/prepdocs/PrepareDocs/AppOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ internal record class AppOptions(
string? AzureOpenAIServiceEndpoint,
string? SearchIndexName,
string? EmbeddingModelName,
int EmbeddingModelDimensions,
bool Remove,
bool RemoveAll,
string? FormRecognizerServiceEndpoint,
Expand Down
2 changes: 2 additions & 0 deletions app/prepdocs/PrepareDocs/Program.Clients.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@ private static Task<AzureSearchEmbedService> GetAzureSearchEmbedService(AppOptio
var blobContainerClient = await GetCorpusBlobContainerClientAsync(o);
var openAIClient = await GetOpenAIClientAsync(o);
var embeddingModelName = o.EmbeddingModelName ?? throw new ArgumentNullException(nameof(o.EmbeddingModelName));
var embeddingModelDimensions = o.EmbeddingModelDimensions;
var searchIndexName = o.SearchIndexName ?? throw new ArgumentNullException(nameof(o.SearchIndexName));
var computerVisionService = await GetComputerVisionServiceAsync(o);

return new AzureSearchEmbedService(
openAIClient: openAIClient,
embeddingModelName: embeddingModelName,
embeddingModelDimensions: embeddingModelDimensions,
searchClient: searchClient,
searchIndexName: searchIndexName,
searchIndexClient: searchIndexClient,
Expand Down
5 changes: 5 additions & 0 deletions app/prepdocs/PrepareDocs/Program.Options.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ internal static partial class Program
private static readonly Option<string> s_embeddingModelName =
new(name: "--embeddingmodel", description: "Optional. Name of the Azure AI Search embedding model to use for embedding content in the search index (will be created if it doesn't exist)");

// Number of dimensions of the embedding vectors. A default is supplied explicitly:
// without one, omitting the flag binds default(int) == 0, which would then be used
// both as the search index vector size and as the requested embedding size.
private static readonly Option<int> s_embeddingModelDimensions =
    new(name: "--embeddingmodeldimensions", getDefaultValue: () => 1536, description: "Optional. Number of dimensions of the embedding model used for embedding content in the search index (defaults to 1536)");

private static readonly Option<bool> s_remove =
new(name: "--remove", description: "Remove references to this document from blob storage and the search index");

Expand Down Expand Up @@ -63,6 +66,7 @@ internal static partial class Program
s_searchIndexName,
s_azureOpenAIService,
s_embeddingModelName,
s_embeddingModelDimensions,
s_remove,
s_removeAll,
s_formRecognizerServiceEndpoint,
Expand All @@ -81,6 +85,7 @@ internal static partial class Program
SearchIndexName: context.ParseResult.GetValueForOption(s_searchIndexName),
AzureOpenAIServiceEndpoint: context.ParseResult.GetValueForOption(s_azureOpenAIService),
EmbeddingModelName: context.ParseResult.GetValueForOption(s_embeddingModelName),
EmbeddingModelDimensions: context.ParseResult.GetValueForOption(s_embeddingModelDimensions),
Remove: context.ParseResult.GetValueForOption(s_remove),
RemoveAll: context.ParseResult.GetValueForOption(s_removeAll),
FormRecognizerServiceEndpoint: context.ParseResult.GetValueForOption(s_formRecognizerServiceEndpoint),
Expand Down
5 changes: 3 additions & 2 deletions app/shared/Shared/Services/AzureSearchEmbedService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
public sealed partial class AzureSearchEmbedService(
OpenAIClient openAIClient,
string embeddingModelName,
int embeddingModelDimensions,
SearchClient searchClient,
string searchIndexName,
SearchIndexClient searchIndexClient,
Expand Down Expand Up @@ -135,7 +136,7 @@ public async Task CreateSearchIndexAsync(string searchIndexName, CancellationTok
new SimpleField("sourcefile", SearchFieldDataType.String) { IsFacetable = true },
new SearchField("embedding", SearchFieldDataType.Collection(SearchFieldDataType.Single))
{
VectorSearchDimensions = 1536,
VectorSearchDimensions = embeddingModelDimensions,
IsSearchable = true,
VectorSearchProfileName = vectorSearchProfile,
}
Expand Down Expand Up @@ -449,7 +450,7 @@ private async Task IndexSectionsAsync(IEnumerable<Section> sections)
var batch = new IndexDocumentsBatch<SearchDocument>();
foreach (var section in sections)
{
var embeddings = await openAIClient.GetEmbeddingsAsync(new Azure.AI.OpenAI.EmbeddingsOptions(embeddingModelName, [section.Content.Replace('\r', ' ')]));
var embeddings = await openAIClient.GetEmbeddingsAsync(new Azure.AI.OpenAI.EmbeddingsOptions(embeddingModelName, [section.Content.Replace('\r', ' ')]) { Dimensions = embeddingModelDimensions });
var embedding = embeddings.Value.Data.FirstOrDefault()?.Embedding.ToArray() ?? [];
batch.Actions.Add(new IndexDocumentsAction<SearchDocument>(
IndexActionType.MergeOrUpload,
Expand Down
28 changes: 28 additions & 0 deletions app/tests/MinimalApi.Tests/AzureDocumentSearchServiceTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,34 @@ public async Task QueryDocumentsTestTextOnlyAsync()
records.Count().Should().Be(3);
}

/// <summary>
/// Requests a query embedding with the dimension count taken from
/// AZURE_OPENAI_EMBEDDING_MODEL_DIMENSIONS and verifies that a vector-only search
/// with <c>Top = 3</c> returns three records.
/// </summary>
/// <exception cref="InvalidOperationException">A required environment variable is not set.</exception>
[EnvironmentVariablesFact("AZURE_SEARCH_INDEX", "AZURE_SEARCH_SERVICE_ENDPOINT", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_EMBEDDING_DEPLOYMENT", "AZURE_OPENAI_EMBEDDING_MODEL_DIMENSIONS")]
public async Task QueryDocumentsTestEmbeddingOnlyDimensionsAsync()
{
    var index = Environment.GetEnvironmentVariable("AZURE_SEARCH_INDEX") ?? throw new InvalidOperationException();
    var searchServiceEndpoint = Environment.GetEnvironmentVariable("AZURE_SEARCH_SERVICE_ENDPOINT") ?? throw new InvalidOperationException();
    var openAiEndpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new InvalidOperationException();
    var openAiEmbeddingDeployment = Environment.GetEnvironmentVariable("AZURE_OPENAI_EMBEDDING_DEPLOYMENT") ?? throw new InvalidOperationException();

    // Fail the same way as the other variables when unset (instead of a FormatException
    // from parsing ""), and parse with the invariant culture since this is configuration.
    var openAiEmbeddingModelDimensions = int.Parse(
        Environment.GetEnvironmentVariable("AZURE_OPENAI_EMBEDDING_MODEL_DIMENSIONS") ?? throw new InvalidOperationException(),
        System.Globalization.CultureInfo.InvariantCulture);

    var openAIClient = new OpenAIClient(new Uri(openAiEndpoint), new DefaultAzureCredential());
    var query = "What is included in my Northwind Health Plus plan that is not in standard?";
    var embeddingResponse = await openAIClient.GetEmbeddingsAsync(new EmbeddingsOptions(openAiEmbeddingDeployment, [query]) { Dimensions = openAiEmbeddingModelDimensions });
    var embedding = embeddingResponse.Value.Data.First().Embedding;
    var searchClient = new SearchClient(new Uri(searchServiceEndpoint), index, new DefaultAzureCredential());
    var service = new AzureSearchService(searchClient);

    // vector retrieval only
    var option = new RequestOverrides
    {
        RetrievalMode = RetrievalMode.Vector,
        Top = 3,
        SemanticCaptions = true,
        SemanticRanker = true,
    };

    var records = await service.QueryDocumentsAsync(query: query, embedding: embedding.ToArray(), overrides: option);
    records.Count().Should().Be(3);
}

[EnvironmentVariablesFact("AZURE_SEARCH_INDEX", "AZURE_SEARCH_SERVICE_ENDPOINT", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_EMBEDDING_DEPLOYMENT")]
public async Task QueryDocumentsTestEmbeddingOnlyAsync()
{
Expand Down
Loading
Loading