diff --git a/app/Directory.Packages.props b/app/Directory.Packages.props
index 02ca95bd..968e5de6 100644
--- a/app/Directory.Packages.props
+++ b/app/Directory.Packages.props
@@ -39,6 +39,7 @@
+    <PackageVersion Include="Microsoft.ML.Tokenizers" Version="..." />
diff --git a/app/shared/Shared/Services/AzureSearchEmbedService.cs b/app/shared/Shared/Services/AzureSearchEmbedService.cs
index ab31d2b8..d50951b5 100644
--- a/app/shared/Shared/Services/AzureSearchEmbedService.cs
+++ b/app/shared/Shared/Services/AzureSearchEmbedService.cs
@@ -13,6 +13,7 @@
 using Azure.Storage.Blobs;
 using Azure.Storage.Blobs.Models;
 using Microsoft.Extensions.Logging;
+using Microsoft.ML.Tokenizers;
 using Shared.Models;
 
 public sealed partial class AzureSearchEmbedService(
@@ -25,11 +26,16 @@ public sealed partial class AzureSearchEmbedService(
     BlobContainerClient corpusContainerClient,
     IComputerVisionService? computerVisionService = null,
     bool includeImageEmbeddingsField = false,
-    ILogger? logger = null) : IEmbedService
+    ILogger? logger = null,
+    string modelEncodingName = "gpt-4",
+    int maxTokensPerSection = 500) : IEmbedService
 {
     [GeneratedRegex("[^0-9a-zA-Z_-]")]
     private static partial Regex MatchInSetRegex();
 
+    private static readonly char[] s_sentenceEndings = ['.', '。', '．', '!', '?', '‼', '⁇', '⁈', '⁉'];
+    private const int DefaultOverlapPercent = 10;
+
     public async Task EmbedPDFBlobAsync(Stream pdfBlobStream, string blobName)
     {
         try
@@ -48,7 +54,11 @@ public async Task EmbedPDFBlobAsync(Stream pdfBlobStream, string blobName)
                 await UploadCorpusAsync(corpusName, page.Text);
             }
 
-            var sections = CreateSections(pageMap, blobName);
+            List<Section> sections = [];
+            await foreach (var section in CreateSectionsAsync(pageMap, blobName))
+            {
+                sections.Add(section);
+            }
 
             var infoLoggingEnabled = logger?.IsEnabled(LogLevel.Information);
             if (infoLoggingEnabled is true)
@@ -109,7 +119,7 @@ public async Task EmbedImageBlobAsync(
         return true;
     }
 
-    public async Task CreateSearchIndexAsync(string searchIndexName, CancellationToken ct = default)
+    public async Task CreateSearchIndexAsync(string searchIndexName, CancellationToken ct = default)
     {
         string vectorSearchConfigName = "my-vector-config";
         string vectorSearchProfile = "my-vector-profile";
@@ -315,18 +325,97 @@ private async Task UploadCorpusAsync(string corpusBlobName, string text)
             });
     }
 
-    public IEnumerable<Section> CreateSections(
+    public async IAsyncEnumerable<Section> SplitSectionByTokenLengthAsync(string Id,
+        string Content,
+        string SourcePage,
+        string SourceFile,
+        string? Category = null,
+        Tokenizer? tokenizer = null)
+    {
+        tokenizer ??= await Tiktoken.CreateByModelNameAsync(modelEncodingName);
+        if (tokenizer.CountTokens(Content) <= maxTokensPerSection)
+        {
+            yield return new Section(
+                Id: Id,
+                Content: Content,
+                SourcePage: SourcePage,
+                SourceFile: SourceFile,
+                Category: Category);
+        }
+        else
+        {
+            int start = Content.Length / 2;
+            int pos = 0;
+            int boundary = Content.Length / 3;
+            int splitPosition = -1;
+
+            // Scan outward from the middle for the closest sentence ending,
+            // staying within the middle third of the content.
+            while (start - pos > boundary)
+            {
+                if (s_sentenceEndings.Contains(Content[start - pos]))
+                {
+                    splitPosition = start - pos;
+                    break;
+                }
+                else if (s_sentenceEndings.Contains(Content[start + pos]))
+                {
+                    splitPosition = start + pos;
+                    break;
+                }
+                else
+                {
+                    pos += 1;
+                }
+            }
+
+            string firstHalf, secondHalf;
+
+            if (splitPosition > 0)
+            {
+                firstHalf = Content[..(splitPosition + 1)];
+                secondHalf = Content[(splitPosition + 1)..];
+            }
+            else
+            {
+                // No sentence ending found: split the page in half and recurse,
+                // overlapping the two halves by DefaultOverlapPercent.
+                int middle = Content.Length / 2;
+                int overlapChars = (int)(Content.Length * (DefaultOverlapPercent / 100.0));
+                firstHalf = Content[..(middle + overlapChars)];
+                secondHalf = Content[(middle - overlapChars)..];
+            }
+
+            await foreach (var section in SplitSectionByTokenLengthAsync(Id, firstHalf, SourcePage, SourceFile, Category, tokenizer))
+            {
+                yield return section;
+            }
+            await foreach (var section in SplitSectionByTokenLengthAsync(Id, secondHalf, SourcePage, SourceFile, Category, tokenizer))
+            {
+                yield return section;
+            }
+        }
+    }
+
+    public async IAsyncEnumerable<Section> CreateSectionsAsync(
         IReadOnlyList<PageDetail> pageMap, string blobName)
     {
         const int MaxSectionLength = 1_000;
         const int SentenceSearchLimit = 100;
         const int SectionOverlap = 100;
 
-        var sentenceEndings = new[] { '.', '!', '?' };
-        var wordBreaks = new[] { ',', ';', ':', ' ', '(', ')', '[', ']', '{', '}', '\t', '\n' };
+        var wordBreaks = new[] { ',', '、', ';', ':', ' ', '(', ')', '[', ']', '{', '}', '\t', '\n' };
         var allText = string.Concat(pageMap.Select(p => p.Text));
         var length = allText.Length;
         var start = 0;
+
+        if (length <= MaxSectionLength)
+        {
+            await foreach (var section in SplitSectionByTokenLengthAsync(
+                Id: MatchInSetRegex().Replace($"{blobName}-{start}", "_").TrimStart('_'),
+                Content: allText,
+                SourcePage: BlobNameFromFilePage(blobName, FindPage(pageMap, start)),
+                SourceFile: blobName)) { yield return section; }
+
+            yield break; // The whole document fit in one pass; skip the chunking loop below, which would otherwise yield the same text again.
+        }
+
         var end = length;
 
         logger?.LogInformation("Splitting '{BlobName}' into sections", blobName);
@@ -343,7 +432,7 @@ public IEnumerable<Section> CreateSections(
         else
         {
             // Try to find the end of the sentence
-            while (end < length && (end - start - MaxSectionLength) < SentenceSearchLimit && !sentenceEndings.Contains(allText[end]))
+            while (end < length && (end - start - MaxSectionLength) < SentenceSearchLimit && !s_sentenceEndings.Contains(allText[end]))
             {
                 if (wordBreaks.Contains(allText[end]))
                 {
@@ -352,7 +441,7 @@ public IEnumerable<Section> CreateSections(
                 end++;
             }
 
-            if (end < length && !sentenceEndings.Contains(allText[end]) && lastWord > 0)
+            if (end < length && !s_sentenceEndings.Contains(allText[end]) && lastWord > 0)
             {
                 end = lastWord; // Fall back to at least keeping a whole word
             }
@@ -366,7 +455,7 @@ public IEnumerable<Section> CreateSections(
         // Try to find the start of the sentence or at least a whole word boundary
         lastWord = -1;
         while (start > 0 && start > end - MaxSectionLength -
-            (2 * SentenceSearchLimit) && !sentenceEndings.Contains(allText[start]))
+            (2 * SentenceSearchLimit) && !s_sentenceEndings.Contains(allText[start]))
         {
             if (wordBreaks.Contains(allText[start]))
             {
@@ -375,7 +464,7 @@ public IEnumerable<Section> CreateSections(
             start--;
         }
 
-        if (!sentenceEndings.Contains(allText[start]) && lastWord > 0)
+        if (!s_sentenceEndings.Contains(allText[start]) && lastWord > 0)
         {
             start = lastWord;
         }
@@ -386,11 +475,11 @@ public IEnumerable<Section> CreateSections(
 
         var sectionText = allText[start..end];
 
-        yield return new Section(
+        await foreach (var section in SplitSectionByTokenLengthAsync(
             Id: MatchInSetRegex().Replace($"{blobName}-{start}", "_").TrimStart('_'),
             Content: sectionText,
             SourcePage: BlobNameFromFilePage(blobName, FindPage(pageMap, start)),
-            SourceFile: blobName);
+            SourceFile: blobName)) { yield return section; }
 
         var lastTableStart = sectionText.LastIndexOf("<table", StringComparison.Ordinal);
         if (lastTableStart > 2 * SentenceSearchLimit && lastTableStart > sectionText.LastIndexOf("</table", StringComparison.Ordinal))
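Before the test file, the core behavioral change is easiest to see in isolation: a section is kept whole only if the tokenizer counts it at or below maxTokensPerSection; otherwise it is split near a sentence boundary and each half is re-checked recursively. The sketch below exercises the same Microsoft.ML.Tokenizers calls the diff introduces (Tiktoken.CreateByModelNameAsync, CountTokens); the sample text and threshold are illustrative, not taken from the PR, and the exact factory API depends on the package version pinned above.

    // Illustrative sketch only: mirrors the keep-or-split decision in
    // SplitSectionByTokenLengthAsync. "gpt-4" and 500 match the new
    // constructor defaults; the candidate text is made up.
    using Microsoft.ML.Tokenizers;

    Tokenizer tokenizer = await Tiktoken.CreateByModelNameAsync("gpt-4");

    string candidate = "A candidate section of extracted PDF text...";
    int tokenCount = tokenizer.CountTokens(candidate);

    const int maxTokensPerSection = 500;
    Console.WriteLine(tokenCount <= maxTokensPerSection
        ? $"Keep as one section ({tokenCount} tokens)."
        : $"Too long ({tokenCount} tokens); split near a sentence boundary and recurse.");
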
diff --git a/app/tests/MinimalApi.Tests/AzureSearchEmbedServiceTest.cs b/app/tests/MinimalApi.Tests/AzureSearchEmbedServiceTest.cs
index af356f76..7a9b3791 100644
--- a/app/tests/MinimalApi.Tests/AzureSearchEmbedServiceTest.cs
+++ b/app/tests/MinimalApi.Tests/AzureSearchEmbedServiceTest.cs
@@ -16,7 +16,10 @@
 using Azure.Storage.Blobs;
 using FluentAssertions;
 using Microsoft.Extensions.Logging;
+using MudBlazor.Services;
 using NSubstitute;
+using Shared.Models;
+using Xunit;
 
 namespace MinimalApi.Tests;
 
 public class AzureSearchEmbedServiceTest
@@ -312,4 +315,89 @@ public async Task EmbedImageBlobTestAsync()
             await blobServiceClient.DeleteBlobContainerAsync(blobContainer);
         }
     }
+
+    [Fact]
+    public async Task EnsureTextSplitsOnTinySectionAsync()
+    {
+        var indexName = nameof(EnsureTextSplitsOnTinySectionAsync).ToLower();
+        var openAIEndpoint = "https://fake.openai.azure.com";
+        var embeddingDeployment = "gpt-4";
+        var azureSearchEndpoint = "https://fake-search.search.azure.com";
+        var blobEndpoint = "https://fake-storage.azure.com/";
+        var blobContainer = "test";
+
+        var azureCredential = new DefaultAzureCredential();
+        var openAIClient = new OpenAIClient(new Uri(openAIEndpoint), azureCredential);
+        var searchClient = new SearchClient(new Uri(azureSearchEndpoint), indexName, azureCredential);
+        var searchIndexClient = new SearchIndexClient(new Uri(azureSearchEndpoint), azureCredential);
+        var documentAnalysisClient = new DocumentAnalysisClient(new Uri(azureSearchEndpoint), azureCredential);
+        var blobServiceClient = new BlobServiceClient(new Uri(blobEndpoint), azureCredential);
+
+        var service = new AzureSearchEmbedService(
+            openAIClient: openAIClient,
+            embeddingModelName: embeddingDeployment,
+            searchClient: searchClient,
+            searchIndexName: indexName,
+            searchIndexClient: searchIndexClient,
+            documentAnalysisClient: documentAnalysisClient,
+            corpusContainerClient: blobServiceClient.GetBlobContainerClient(blobContainer),
+            computerVisionService: null,
+            includeImageEmbeddingsField: false,
+            logger: null,
+            maxTokensPerSection: 500);
+
+        List<Section> sections = [];
+        IReadOnlyList<PageDetail> pageMap =
+        [
+            new(0, 0, "this is a test")
+        ];
+        await foreach (var section in service.CreateSectionsAsync(pageMap, "test-blob"))
+        {
+            sections.Add(section);
+        }
+
+        sections.Should().HaveCount(1);
+        sections[0].Content.Should().Be("this is a test");
+    }
+
+    [Fact]
+    public async Task EnsureTextSplitsOnBigSectionAsync()
+    {
+        var indexName = nameof(EnsureTextSplitsOnBigSectionAsync).ToLower();
+        var openAIEndpoint = "https://fake.openai.azure.com";
+        var embeddingDeployment = "embeddings";
+        var azureSearchEndpoint = "https://fake-search.search.azure.com";
+        var blobEndpoint = "https://fake-storage.azure.com/";
+        var blobContainer = "test";
+
+        var azureCredential = new DefaultAzureCredential();
+        var openAIClient = new OpenAIClient(new Uri(openAIEndpoint), azureCredential);
+        var searchClient = new SearchClient(new Uri(azureSearchEndpoint), indexName, azureCredential);
+        var searchIndexClient = new SearchIndexClient(new Uri(azureSearchEndpoint), azureCredential);
+        var documentAnalysisClient = new DocumentAnalysisClient(new Uri(azureSearchEndpoint), azureCredential);
+        var blobServiceClient = new BlobServiceClient(new Uri(blobEndpoint), azureCredential);
+
+        var service = new AzureSearchEmbedService(
+            openAIClient: openAIClient,
+            embeddingModelName: embeddingDeployment,
+            searchClient: searchClient,
+            searchIndexName: indexName,
+            searchIndexClient: searchIndexClient,
+            documentAnalysisClient: documentAnalysisClient,
+            corpusContainerClient: blobServiceClient.GetBlobContainerClient(blobContainer),
+            computerVisionService: null,
+            includeImageEmbeddingsField: false,
+            logger: null,
+            maxTokensPerSection: 500);
+
+        List<Section> sections = [];
+        string testContent = "".PadRight(1000, ' ');
+        IReadOnlyList<PageDetail> pageMap =
+        [
+            new(0, 0, testContent)
+        ];
+        await foreach (var section in service.CreateSectionsAsync(pageMap, "test-blob"))
+        {
+            sections.Add(section);
+        }
+
+        sections.Should().HaveCount(1);
+        sections[0].Content.Should().Be(testContent);
+    }
 }
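For callers, the visible surface change is small: two new optional constructor arguments and an async enumeration in place of the old synchronous one. A minimal consumption sketch, assuming the same client wiring as the tests above; the endpoints, index name, and blob name are placeholders, and the usings for the Azure client types are elided (same set as in the test file):

    // Hypothetical wiring; fake endpoints, as in the tests above.
    var credential = new DefaultAzureCredential();
    var service = new AzureSearchEmbedService(
        openAIClient: new OpenAIClient(new Uri("https://fake.openai.azure.com"), credential),
        embeddingModelName: "embeddings",
        searchClient: new SearchClient(new Uri("https://fake-search.search.azure.com"), "docs", credential),
        searchIndexName: "docs",
        searchIndexClient: new SearchIndexClient(new Uri("https://fake-search.search.azure.com"), credential),
        documentAnalysisClient: new DocumentAnalysisClient(new Uri("https://fake-search.search.azure.com"), credential),
        corpusContainerClient: new BlobServiceClient(new Uri("https://fake-storage.azure.com/"), credential).GetBlobContainerClient("corpus"),
        modelEncodingName: "gpt-4",     // new optional argument: tokenizer model, defaults to "gpt-4"
        maxTokensPerSection: 500);      // new optional argument: split threshold, defaults to 500

    // Sections now arrive asynchronously and are token-bounded.
    List<Section> sections = [];
    IReadOnlyList<PageDetail> pageMap = [new(0, 0, "some extracted page text")];
    await foreach (var section in service.CreateSectionsAsync(pageMap, "document.pdf"))
    {
        sections.Add(section);
    }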