diff --git a/app/Directory.Packages.props b/app/Directory.Packages.props
index 02ca95bd..968e5de6 100644
--- a/app/Directory.Packages.props
+++ b/app/Directory.Packages.props
@@ -39,6 +39,7 @@
+
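The hunk above pins the new package dependency (the added line itself is elided here); given the using directive added below, it presumably registers Microsoft.ML.Tokenizers for central package management. As a minimal sketch of what the dependency provides, assuming the TiktokenTokenizer factory exposed by recent Microsoft.ML.Tokenizers releases (the service itself calls the older Tiktoken.CreateByModelNameAsync helper):

    using Microsoft.ML.Tokenizers;

    // Assumption: TiktokenTokenizer.CreateForModel is the current factory for
    // tiktoken-style encodings; "gpt-4" selects the cl100k_base encoding.
    Tokenizer tokenizer = TiktokenTokenizer.CreateForModel("gpt-4");

    // CountTokens is the call the service below uses to decide whether a
    // section fits within maxTokensPerSection.
    int count = tokenizer.CountTokens("this is a test");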
diff --git a/app/shared/Shared/Services/AzureSearchEmbedService.cs b/app/shared/Shared/Services/AzureSearchEmbedService.cs
index ab31d2b8..d50951b5 100644
--- a/app/shared/Shared/Services/AzureSearchEmbedService.cs
+++ b/app/shared/Shared/Services/AzureSearchEmbedService.cs
@@ -13,6 +13,7 @@
using Azure.Storage.Blobs;
using Azure.Storage.Blobs.Models;
using Microsoft.Extensions.Logging;
+using Microsoft.ML.Tokenizers;
using Shared.Models;
public sealed partial class AzureSearchEmbedService(
@@ -25,11 +26,16 @@ public sealed partial class AzureSearchEmbedService(
BlobContainerClient corpusContainerClient,
IComputerVisionService? computerVisionService = null,
bool includeImageEmbeddingsField = false,
- ILogger? logger = null) : IEmbedService
+ ILogger? logger = null,
+ string modelEncodingName = "gpt-4",
+ int maxTokensPerSection = 500) : IEmbedService
{
[GeneratedRegex("[^0-9a-zA-Z_-]")]
private static partial Regex MatchInSetRegex();
+ private static readonly char[] s_sentenceEndings = ['.', '。', '．', '!', '?', '‼', '⁇', '⁈', '⁉'];
+ private const int DefaultOverlapPercent = 10;
+
public async Task EmbedPDFBlobAsync(Stream pdfBlobStream, string blobName)
{
try
@@ -48,7 +54,11 @@ public async Task EmbedPDFBlobAsync(Stream pdfBlobStream, string blobName)
await UploadCorpusAsync(corpusName, page.Text);
}
- var sections = CreateSections(pageMap, blobName);
+ List<Section> sections = [];
+ await foreach (var section in CreateSectionsAsync(pageMap, blobName))
+ {
+ sections.Add(section);
+ }
var infoLoggingEnabled = logger?.IsEnabled(LogLevel.Information);
if (infoLoggingEnabled is true)
@@ -109,7 +119,7 @@ public async Task EmbedImageBlobAsync(
return true;
}
- public async Task CreateSearchIndexAsync(string searchIndexName, CancellationToken ct = default)
+ public async Task CreateSearchIndexAsync(string searchIndexName, CancellationToken ct = default)
{
string vectorSearchConfigName = "my-vector-config";
string vectorSearchProfile = "my-vector-profile";
@@ -315,18 +325,97 @@ private async Task UploadCorpusAsync(string corpusBlobName, string text)
});
}
- public IEnumerable<Section> CreateSections(
+ public async IAsyncEnumerable<Section> SplitSectionByTokenLengthAsync(string Id,
+ string Content,
+ string SourcePage,
+ string SourceFile,
+ string? Category = null,
+ Tokenizer? tokenizer = null)
+ {
+ tokenizer ??= await Tiktoken.CreateByModelNameAsync(modelEncodingName);
+ if (tokenizer.CountTokens(Content) <= maxTokensPerSection)
+ {
+ yield return new Section(
+ Id: Id,
+ Content: Content,
+ SourcePage: SourcePage,
+ SourceFile: SourceFile,
+ Category: Category);
+ }
+ else
+ {
+ int start = Content.Length / 2;
+ int pos = 0;
+ int boundary = Content.Length / 3;
+ int splitPosition = -1;
+
+ while (start - pos > boundary)
+ {
+ if (s_sentenceEndings.Contains(Content[start - pos]))
+ {
+ splitPosition = start - pos;
+ break;
+ }
+ else if (s_sentenceEndings.Contains(Content[start + pos]))
+ {
+ splitPosition = start + pos;
+ break;
+ }
+ else
+ {
+ pos += 1;
+ }
+ }
+
+ string firstHalf, secondHalf;
+
+ if (splitPosition > 0)
+ {
+ firstHalf = Content[..(splitPosition + 1)];
+ secondHalf = Content[(splitPosition + 1)..];
+ }
+ else
+ {
+ // Split page in half and call function again
+ // Overlap the first and second halves by DefaultOverlapPercent%
+ int middle = Content.Length / 2;
+ int overlapChars = (int)(Content.Length * (DefaultOverlapPercent / 100.0));
+ firstHalf = Content[..(middle + overlapChars)];
+ secondHalf = Content[(middle - overlapChars)..];
+ }
+
+ await foreach (var section in SplitSectionByTokenLengthAsync(Id, firstHalf, SourcePage, SourceFile, Category, tokenizer))
+ {
+ yield return section;
+ }
+ await foreach (var section in SplitSectionByTokenLengthAsync(Id, secondHalf, SourcePage, SourceFile, Category, tokenizer))
+ {
+ yield return section;
+ }
+ }
+
+ }
+
+ public async IAsyncEnumerable<Section> CreateSectionsAsync(
IReadOnlyList<PageDetail> pageMap, string blobName)
{
const int MaxSectionLength = 1_000;
const int SentenceSearchLimit = 100;
const int SectionOverlap = 100;
- var sentenceEndings = new[] { '.', '!', '?' };
- var wordBreaks = new[] { ',', ';', ':', ' ', '(', ')', '[', ']', '{', '}', '\t', '\n' };
+ var wordBreaks = new[] { ',', '、', ';', ':', ' ', '(', ')', '[', ']', '{', '}', '\t', '\n' };
var allText = string.Concat(pageMap.Select(p => p.Text));
var length = allText.Length;
var start = 0;
+
+ if (length <= MaxSectionLength)
+ {
+ await foreach (var section in SplitSectionByTokenLengthAsync(
+ Id: MatchInSetRegex().Replace($"{blobName}-{start}", "_").TrimStart('_'),
+ Content: allText,
+ SourcePage: BlobNameFromFilePage(blobName, FindPage(pageMap, start)),
+ SourceFile: blobName)) { yield return section; }
+ // The whole text already fits in one pass; bail out so the loop below
+ // does not emit the same content a second time.
+ yield break;
+ }
+
var end = length;
logger?.LogInformation("Splitting '{BlobName}' into sections", blobName);
@@ -343,7 +432,7 @@ public IEnumerable CreateSections(
else
{
// Try to find the end of the sentence
- while (end < length && (end - start - MaxSectionLength) < SentenceSearchLimit && !sentenceEndings.Contains(allText[end]))
+ while (end < length && (end - start - MaxSectionLength) < SentenceSearchLimit && !s_sentenceEndings.Contains(allText[end]))
{
if (wordBreaks.Contains(allText[end]))
{
@@ -352,7 +441,7 @@ public IEnumerable CreateSections(
end++;
}
- if (end < length && !sentenceEndings.Contains(allText[end]) && lastWord > 0)
+ if (end < length && !s_sentenceEndings.Contains(allText[end]) && lastWord > 0)
{
end = lastWord; // Fall back to at least keeping a whole word
}
@@ -366,7 +455,7 @@ public IEnumerable CreateSections(
// Try to find the start of the sentence or at least a whole word boundary
lastWord = -1;
while (start > 0 && start > end - MaxSectionLength -
- (2 * SentenceSearchLimit) && !sentenceEndings.Contains(allText[start]))
+ (2 * SentenceSearchLimit) && !s_sentenceEndings.Contains(allText[start]))
{
if (wordBreaks.Contains(allText[start]))
{
@@ -375,7 +464,7 @@ public IEnumerable CreateSections(
start--;
}
- if (!sentenceEndings.Contains(allText[start]) && lastWord > 0)
+ if (!s_sentenceEndings.Contains(allText[start]) && lastWord > 0)
{
start = lastWord;
}
@@ -386,11 +475,11 @@ public IEnumerable CreateSections(
var sectionText = allText[start..end];
- yield return new Section(
+ await foreach (var section in SplitSectionByTokenLengthAsync(
Id: MatchInSetRegex().Replace($"{blobName}-{start}", "_").TrimStart('_'),
Content: sectionText,
SourcePage: BlobNameFromFilePage(blobName, FindPage(pageMap, start)),
- SourceFile: blobName);
+ SourceFile: blobName)) { yield return section; }
var lastTableStart = sectionText.LastIndexOf("<table", StringComparison.Ordinal);
if (lastTableStart > 2 * SentenceSearchLimit && lastTableStart > sectionText.LastIndexOf("</table", StringComparison.Ordinal))
+
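SplitSectionByTokenLengthAsync implements a recursive binary split: a section that exceeds maxTokensPerSection is divided at a sentence ending found near the midpoint (searching outward within the middle third of the text), falling back to a hard split at the midpoint with a 10% character overlap, and each half is split again until every piece fits. A self-contained sketch of the same strategy, with illustrative names and a pluggable token counter rather than the PR's tokenizer:

    using System;
    using System.Collections.Generic;
    using System.Linq;

    static class TokenSplitter
    {
        // Illustrative restatement of the PR's splitting strategy; SplitByTokens
        // and countTokens are assumptions, not the PR's API.
        public static IEnumerable<string> SplitByTokens(
            string text, Func<string, int> countTokens, int maxTokens)
        {
            if (countTokens(text) <= maxTokens)
            {
                yield return text;
                yield break;
            }

            char[] endings = ['.', '!', '?'];
            int middle = text.Length / 2;
            int boundary = text.Length / 3; // only search the middle third
            int split = -1;
            for (int pos = 0; middle - pos > boundary; pos++)
            {
                if (endings.Contains(text[middle - pos])) { split = middle - pos; break; }
                if (endings.Contains(text[middle + pos])) { split = middle + pos; break; }
            }

            string first, second;
            if (split > 0)
            {
                // Split just after the sentence ending.
                first = text[..(split + 1)];
                second = text[(split + 1)..];
            }
            else
            {
                // No sentence boundary found: hard split at the midpoint with a
                // 10% overlap so neither half loses its surrounding context.
                int overlap = text.Length / 10;
                first = text[..(middle + overlap)];
                second = text[(middle - overlap)..];
            }

            foreach (var piece in SplitByTokens(first, countTokens, maxTokens)) { yield return piece; }
            foreach (var piece in SplitByTokens(second, countTokens, maxTokens)) { yield return piece; }
        }
    }

Used, for example, with a crude length-based counter: TokenSplitter.SplitByTokens(longText, s => s.Length / 4, 500).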
diff --git a/app/tests/MinimalApi.Tests/AzureSearchEmbedServiceTest.cs b/app/tests/MinimalApi.Tests/AzureSearchEmbedServiceTest.cs
index af356f76..7a9b3791 100644
--- a/app/tests/MinimalApi.Tests/AzureSearchEmbedServiceTest.cs
+++ b/app/tests/MinimalApi.Tests/AzureSearchEmbedServiceTest.cs
@@ -16,7 +16,10 @@
using Azure.Storage.Blobs;
using FluentAssertions;
using Microsoft.Extensions.Logging;
+using MudBlazor.Services;
using NSubstitute;
+using Shared.Models;
+using Xunit;
namespace MinimalApi.Tests;
public class AzureSearchEmbedServiceTest
@@ -312,4 +315,89 @@ public async Task EmbedImageBlobTestAsync()
await blobServiceClient.DeleteBlobContainerAsync(blobContainer);
}
}
+
+ [Fact]
+ public async Task EnsureTextSplitsOnTinySectionAsync()
+ {
+ var indexName = nameof(EnsureTextSplitsOnTinySectionAsync).ToLower();
+ var openAIEndpoint = "https://fake.openai.azure.com";
+ var embeddingDeployment = "gpt-4";
+ var azureSearchEndpoint = "https://fake-search.search.azure.com";
+ var blobEndpoint = "https://fake-storage.azure.com/";
+ var blobContainer = "test";
+
+ var azureCredential = new DefaultAzureCredential();
+ var openAIClient = new OpenAIClient(new Uri(openAIEndpoint), azureCredential);
+ var searchClient = new SearchClient(new Uri(azureSearchEndpoint), indexName, azureCredential);
+ var searchIndexClient = new SearchIndexClient(new Uri(azureSearchEndpoint), azureCredential);
+ var documentAnalysisClient = new DocumentAnalysisClient(new Uri(azureSearchEndpoint), azureCredential);
+ var blobServiceClient = new BlobServiceClient(new Uri(blobEndpoint), azureCredential);
+
+ var service = new AzureSearchEmbedService(
+ openAIClient: openAIClient,
+ embeddingModelName: embeddingDeployment,
+ searchClient: searchClient,
+ searchIndexName: indexName,
+ searchIndexClient: searchIndexClient,
+ documentAnalysisClient: documentAnalysisClient,
+ corpusContainerClient: blobServiceClient.GetBlobContainerClient(blobContainer),
+ computerVisionService: null,
+ includeImageEmbeddingsField: false,
+ logger: null,
+ maxTokensPerSection: 500);
+ List<Section> sections = [];
+ IReadOnlyList<PageDetail> pageMap =
+ [
+ new(0, 0, "this is a test")
+ ];
+ await foreach (var section in service.CreateSectionsAsync(pageMap, "test-blob"))
+ {
+ sections.Add(section);
+ }
+ sections.Should().HaveCount(1);
+ sections[0].Content.Should().Be("this is a test");
+ }
+
+ [Fact]
+ public async Task EnsureTextSplitsOnBigSectionAsync()
+ {
+ var indexName = nameof(EnsureTextSplitsOnBigSectionAsync).ToLower();
+ var openAIEndpoint = "https://fake.openai.azure.com";
+ var embeddingDeployment = "embeddings";
+ var azureSearchEndpoint = "https://fake-search.search.azure.com";
+ var blobEndpoint = "https://fake-storage.azure.com/";
+ var blobContainer = "test";
+
+ var azureCredential = new DefaultAzureCredential();
+ var openAIClient = new OpenAIClient(new Uri(openAIEndpoint), azureCredential);
+ var searchClient = new SearchClient(new Uri(azureSearchEndpoint), indexName, azureCredential);
+ var searchIndexClient = new SearchIndexClient(new Uri(azureSearchEndpoint), azureCredential);
+ var documentAnalysisClient = new DocumentAnalysisClient(new Uri(azureSearchEndpoint), azureCredential);
+ var blobServiceClient = new BlobServiceClient(new Uri(blobEndpoint), azureCredential);
+
+ var service = new AzureSearchEmbedService(
+ openAIClient: openAIClient,
+ embeddingModelName: embeddingDeployment,
+ searchClient: searchClient,
+ searchIndexName: indexName,
+ searchIndexClient: searchIndexClient,
+ documentAnalysisClient: documentAnalysisClient,
+ corpusContainerClient: blobServiceClient.GetBlobContainerClient(blobContainer),
+ computerVisionService: null,
+ includeImageEmbeddingsField: false,
+ logger: null,
+ maxTokensPerSection: 500);
+ List<Section> sections = [];
+ string testContent = "".PadRight(1000, ' ');
+ IReadOnlyList<PageDetail> pageMap =
+ [
+ new(0, 0, testContent)
+ ];
+ await foreach (var section in service.CreateSectionsAsync(pageMap, "test-blob"))
+ {
+ sections.Add(section);
+ }
+ sections.Should().HaveCount(1);
+ sections[0].Content.Should().Be(testContent);
+ }
}
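Both tests construct real SDK clients against fake endpoints; that works because CreateSectionsAsync never touches them, only the tokenizer, so the splitter can be exercised offline. A hedged sketch of a third test (not part of this PR) that forces an actual split by exceeding MaxSectionLength:

    // Assumption: same fixture setup as the tests above; only the input differs.
    [Fact]
    public async Task EnsureLongTextYieldsMultipleSectionsAsync()
    {
        var azureCredential = new DefaultAzureCredential();
        var service = new AzureSearchEmbedService(
            openAIClient: new OpenAIClient(new Uri("https://fake.openai.azure.com"), azureCredential),
            embeddingModelName: "embeddings",
            searchClient: new SearchClient(new Uri("https://fake-search.search.azure.com"), "test-index", azureCredential),
            searchIndexName: "test-index",
            searchIndexClient: new SearchIndexClient(new Uri("https://fake-search.search.azure.com"), azureCredential),
            documentAnalysisClient: new DocumentAnalysisClient(new Uri("https://fake-search.search.azure.com"), azureCredential),
            corpusContainerClient: new BlobServiceClient(new Uri("https://fake-storage.azure.com/"), azureCredential).GetBlobContainerClient("test"),
            computerVisionService: null,
            includeImageEmbeddingsField: false,
            logger: null,
            maxTokensPerSection: 500);

        // Repeat a sentence until the text is well past MaxSectionLength
        // (1,000 chars), so the splitter has to emit more than one section.
        string testContent = string.Concat(Enumerable.Repeat("The quick brown fox jumps over the lazy dog. ", 200));
        IReadOnlyList<PageDetail> pageMap = [new(0, 0, testContent)];

        List<Section> sections = [];
        await foreach (var section in service.CreateSectionsAsync(pageMap, "test-blob"))
        {
            sections.Add(section);
        }

        sections.Should().HaveCountGreaterThan(1);
    }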