Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure chunked PDF documents are never bigger than 500 tokens, support CJK and fix bug with tiny documents #303

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions app/Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
<PackageVersion Include="Microsoft.Maui.Controls" Version="$(MauiVersion)" />
<PackageVersion Include="Microsoft.Maui.Controls.Compatibility" Version="$(MauiVersion)" />
<PackageVersion Include="Microsoft.ML" Version="3.0.0" />
<PackageVersion Include="Microsoft.ML.Tokenizers" Version="0.22.0-preview.24162.2" />
<PackageVersion Include="Microsoft.NET.Test.Sdk" Version="17.8.0" />
<PackageVersion Include="Microsoft.SemanticKernel" Version="1.3.0" />
<PackageVersion Include="Microsoft.VisualStudio.Azure.Containers.Tools.Targets" Version="1.19.5" />
Expand Down
117 changes: 103 additions & 14 deletions app/shared/Shared/Services/AzureSearchEmbedService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
using Azure.Storage.Blobs;
using Azure.Storage.Blobs.Models;
using Microsoft.Extensions.Logging;
using Microsoft.ML.Tokenizers;
using Shared.Models;

public sealed partial class AzureSearchEmbedService(
Expand All @@ -25,11 +26,16 @@ public sealed partial class AzureSearchEmbedService(
BlobContainerClient corpusContainerClient,
IComputerVisionService? computerVisionService = null,
bool includeImageEmbeddingsField = false,
ILogger<AzureSearchEmbedService>? logger = null) : IEmbedService
ILogger<AzureSearchEmbedService>? logger = null,
string modelEncodingName = "gpt-4",
int maxTokensPerSection = 500) : IEmbedService
{
[GeneratedRegex("[^0-9a-zA-Z_-]")]
private static partial Regex MatchInSetRegex();

private static readonly char[] s_sentenceEndings = ['.', '。', '.', '!', '?', '‼', '⁇', '⁈', '⁉'];
private const int DefaultOverlapPercent = 10;

public async Task<bool> EmbedPDFBlobAsync(Stream pdfBlobStream, string blobName)
{
try
Expand All @@ -48,7 +54,11 @@ public async Task<bool> EmbedPDFBlobAsync(Stream pdfBlobStream, string blobName)
await UploadCorpusAsync(corpusName, page.Text);
}

var sections = CreateSections(pageMap, blobName);
List<Section> sections = [];
await foreach(var section in CreateSectionsAsync(pageMap, blobName))
{
sections.Add(section);
}

var infoLoggingEnabled = logger?.IsEnabled(LogLevel.Information);
if (infoLoggingEnabled is true)
Expand Down Expand Up @@ -109,7 +119,7 @@ public async Task<bool> EmbedImageBlobAsync(
return true;
}

public async Task CreateSearchIndexAsync(string searchIndexName, CancellationToken ct = default)
public async Task CreateSearchIndexAsync(string searchIndexName, CancellationToken ct = default)
{
string vectorSearchConfigName = "my-vector-config";
string vectorSearchProfile = "my-vector-profile";
Expand Down Expand Up @@ -315,18 +325,97 @@ private async Task UploadCorpusAsync(string corpusBlobName, string text)
});
}

public IEnumerable<Section> CreateSections(
/// <summary>
/// Splits <paramref name="Content"/> into one or more <see cref="Section"/> instances, each at
/// most <c>maxTokensPerSection</c> tokens long according to <paramref name="tokenizer"/>.
/// Content that already fits is yielded unchanged; otherwise the text is cut near its middle,
/// preferring a sentence ending (including CJK endings from <c>s_sentenceEndings</c>), and each
/// half is processed recursively.
/// </summary>
/// <param name="Id">Identifier carried through to every emitted section.</param>
/// <param name="Content">The text to split.</param>
/// <param name="SourcePage">Source page reference carried through to every emitted section.</param>
/// <param name="SourceFile">Source file name carried through to every emitted section.</param>
/// <param name="Category">Optional category carried through to every emitted section.</param>
/// <param name="tokenizer">
/// Tokenizer used for counting; when omitted, one is created from <c>modelEncodingName</c> on the
/// initial call and reused by the recursive calls.
/// </param>
/// <remarks>
/// Parameter names are intentionally PascalCase: existing callers pass them as named arguments
/// (<c>Id:</c>, <c>Content:</c>, …), so renaming them would be a breaking change.
/// </remarks>
public async IAsyncEnumerable<Section> SplitSectionByTokenLengthAsync(string Id,
    string Content,
    string SourcePage,
    string SourceFile,
    string? Category = null,
    Tokenizer? tokenizer = null)
{
    tokenizer ??= await Tiktoken.CreateByModelNameAsync(modelEncodingName);

    // Base cases: the content fits the token budget, or it is too small to split further.
    // The length check guards against infinite recursion: with a degenerate
    // maxTokensPerSection (e.g. 0), splitting a 1-character string yields an empty first
    // half and an unchanged second half, which would recurse forever without this guard.
    if (Content.Length <= 1 || tokenizer.CountTokens(Content) <= maxTokensPerSection)
    {
        yield return new Section(
            Id: Id,
            Content: Content,
            SourcePage: SourcePage,
            SourceFile: SourceFile,
            Category: Category);
        yield break;
    }

    // Search outward from the middle for a sentence ending, but only within the middle
    // third of the text so that neither half becomes degenerately small.
    int middle = Content.Length / 2;
    int offset = 0;
    int boundary = Content.Length / 3;
    int splitPosition = -1;

    while (middle - offset > boundary)
    {
        if (s_sentenceEndings.Contains(Content[middle - offset]))
        {
            splitPosition = middle - offset;
            break;
        }

        if (s_sentenceEndings.Contains(Content[middle + offset]))
        {
            splitPosition = middle + offset;
            break;
        }

        offset++;
    }

    string firstHalf, secondHalf;

    if (splitPosition > 0)
    {
        // Cut just after the sentence ending; no overlap is needed at a sentence boundary.
        firstHalf = Content[..(splitPosition + 1)];
        secondHalf = Content[(splitPosition + 1)..];
    }
    else
    {
        // No sentence boundary found: split at the middle and overlap the halves by
        // DefaultOverlapPercent% of the text so no context is lost at the cut.
        int overlapChars = (int)(Content.Length * (DefaultOverlapPercent / 100.0));
        firstHalf = Content[..(middle + overlapChars)];
        secondHalf = Content[(middle - overlapChars)..];
    }

    await foreach (var section in SplitSectionByTokenLengthAsync(Id, firstHalf, SourcePage, SourceFile, Category, tokenizer))
    {
        yield return section;
    }

    await foreach (var section in SplitSectionByTokenLengthAsync(Id, secondHalf, SourcePage, SourceFile, Category, tokenizer))
    {
        yield return section;
    }
}

public async IAsyncEnumerable<Section> CreateSectionsAsync(
IReadOnlyList<PageDetail> pageMap, string blobName)
{
const int MaxSectionLength = 1_000;
const int SentenceSearchLimit = 100;
const int SectionOverlap = 100;

var sentenceEndings = new[] { '.', '!', '?' };
var wordBreaks = new[] { ',', ';', ':', ' ', '(', ')', '[', ']', '{', '}', '\t', '\n' };
var wordBreaks = new[] { ',', '、', ';', ':', ' ', '(', ')', '[', ']', '{', '}', '\t', '\n' };
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lift this to a class-scoped variable to avoid reallocating the same thing multiple times with each call to CreateSectionsAsync.

var allText = string.Concat(pageMap.Select(p => p.Text));
var length = allText.Length;
var start = 0;

if (length <= MaxSectionLength)
{
await foreach (var section in SplitSectionByTokenLengthAsync(
Id: MatchInSetRegex().Replace($"{blobName}-{start}", "_").TrimStart('_'),
Content: allText,
SourcePage: BlobNameFromFilePage(blobName, FindPage(pageMap, start)),
SourceFile: blobName)) { yield return section; }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
SourceFile: blobName)) { yield return section; }
SourceFile: blobName))
{
yield return section;
}

}

var end = length;

logger?.LogInformation("Splitting '{BlobName}' into sections", blobName);
Expand All @@ -343,7 +432,7 @@ public IEnumerable<Section> CreateSections(
else
{
// Try to find the end of the sentence
while (end < length && (end - start - MaxSectionLength) < SentenceSearchLimit && !sentenceEndings.Contains(allText[end]))
while (end < length && (end - start - MaxSectionLength) < SentenceSearchLimit && !s_sentenceEndings.Contains(allText[end]))
{
if (wordBreaks.Contains(allText[end]))
{
Expand All @@ -352,7 +441,7 @@ public IEnumerable<Section> CreateSections(
end++;
}

if (end < length && !sentenceEndings.Contains(allText[end]) && lastWord > 0)
if (end < length && !s_sentenceEndings.Contains(allText[end]) && lastWord > 0)
{
end = lastWord; // Fall back to at least keeping a whole word
}
Expand All @@ -366,7 +455,7 @@ public IEnumerable<Section> CreateSections(
// Try to find the start of the sentence or at least a whole word boundary
lastWord = -1;
while (start > 0 && start > end - MaxSectionLength -
(2 * SentenceSearchLimit) && !sentenceEndings.Contains(allText[start]))
(2 * SentenceSearchLimit) && !s_sentenceEndings.Contains(allText[start]))
{
if (wordBreaks.Contains(allText[start]))
{
Expand All @@ -375,7 +464,7 @@ public IEnumerable<Section> CreateSections(
start--;
}

if (!sentenceEndings.Contains(allText[start]) && lastWord > 0)
if (!s_sentenceEndings.Contains(allText[start]) && lastWord > 0)
{
start = lastWord;
}
Expand All @@ -386,11 +475,11 @@ public IEnumerable<Section> CreateSections(

var sectionText = allText[start..end];

yield return new Section(
await foreach (var section in SplitSectionByTokenLengthAsync(
Id: MatchInSetRegex().Replace($"{blobName}-{start}", "_").TrimStart('_'),
Content: sectionText,
SourcePage: BlobNameFromFilePage(blobName, FindPage(pageMap, start)),
SourceFile: blobName);
SourceFile: blobName)) { yield return section; }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
SourceFile: blobName)) { yield return section; }
SourceFile: blobName))
{
yield return section;
}


var lastTableStart = sectionText.LastIndexOf("<table", StringComparison.Ordinal);
if (lastTableStart > 2 * SentenceSearchLimit && lastTableStart > sectionText.LastIndexOf("</table", StringComparison.Ordinal))
Expand Down Expand Up @@ -419,11 +508,11 @@ table at page {Offset} offset {Start} table start {LastTableStart}

if (start + SectionOverlap < end)
{
yield return new Section(
await foreach (var section in SplitSectionByTokenLengthAsync(
Id: MatchInSetRegex().Replace($"{blobName}-{start}", "_").TrimStart('_'),
Content: allText[start..end],
SourcePage: BlobNameFromFilePage(blobName, FindPage(pageMap, start)),
SourceFile: blobName);
SourceFile: blobName)) { yield return section; }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
SourceFile: blobName)) { yield return section; }
SourceFile: blobName))
{
yield return section;
}

}
}

Expand Down
1 change: 1 addition & 0 deletions app/shared/Shared/Shared.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
<PackageReference Include="Azure.Search.Documents" />
<PackageReference Include="Azure.Storage.Blobs" />
<PackageReference Include="Microsoft.ApplicationInsights.AspNetCore" />
<PackageReference Include="Microsoft.ML.Tokenizers" />
<PackageReference Include="PdfSharpCore" />
</ItemGroup>

Expand Down
88 changes: 88 additions & 0 deletions app/tests/MinimalApi.Tests/AzureSearchEmbedServiceTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
using Azure.Storage.Blobs;
using FluentAssertions;
using Microsoft.Extensions.Logging;
using MudBlazor.Services;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this needed?

using NSubstitute;
using Shared.Models;
using Xunit;

namespace MinimalApi.Tests;
public class AzureSearchEmbedServiceTest
Expand Down Expand Up @@ -312,4 +315,89 @@ public async Task EmbedImageBlobTestAsync()
await blobServiceClient.DeleteBlobContainerAsync(blobContainer);
}
}

[Fact]
public async Task EnsureTextSplitsOnTinySectionAsync()
{
    // Arrange: all clients point at fake endpoints; CreateSectionsAsync only needs the
    // tokenizer and never makes a network call for content this small.
    // Use this test's own name for the index (the previous value was copy-pasted from an
    // unrelated test) and ToLowerInvariant for a culture-independent identifier (CA1304).
    var indexName = nameof(EnsureTextSplitsOnTinySectionAsync).ToLowerInvariant();
    var openAIEndpoint = "https://fake.openai.azure.com";
    var embeddingDeployment = "gpt-4";
    var azureSearchEndpoint = "https://fake-search.search.azure.com";
    var blobEndpoint = "https://fake-storage.azure.com/";
    var blobContainer = "test";

    var azureCredential = new DefaultAzureCredential();
    var openAIClient = new OpenAIClient(new Uri(openAIEndpoint), azureCredential);
    var searchClient = new SearchClient(new Uri(azureSearchEndpoint), indexName, azureCredential);
    var searchIndexClient = new SearchIndexClient(new Uri(azureSearchEndpoint), azureCredential);
    var documentAnalysisClient = new DocumentAnalysisClient(new Uri(azureSearchEndpoint), azureCredential);
    var blobServiceClient = new BlobServiceClient(new Uri(blobEndpoint), azureCredential);

    var service = new AzureSearchEmbedService(
        openAIClient: openAIClient,
        embeddingModelName: embeddingDeployment,
        searchClient: searchClient,
        searchIndexName: indexName,
        searchIndexClient: searchIndexClient,
        documentAnalysisClient: documentAnalysisClient,
        corpusContainerClient: blobServiceClient.GetBlobContainerClient(blobContainer),
        computerVisionService: null,
        includeImageEmbeddingsField: false,
        logger: null,
        maxTokensPerSection: 500);

    List<Section> sections = [];
    IReadOnlyList<PageDetail> pageMap =
    [
        new(0, 0, "this is a test")
    ];

    // Act
    await foreach (var section in service.CreateSectionsAsync(pageMap, "test-blob"))
    {
        sections.Add(section);
    }

    // Assert: content far below the token limit must come back as a single, unmodified section.
    sections.Should().HaveCount(1);
    sections[0].Content.Should().Be("this is a test");
}

[Fact]
public async Task EnsureTextSplitsOnBigSectionAsync()
{
    // Arrange: all clients point at fake endpoints; CreateSectionsAsync only needs the
    // tokenizer and never makes a network call for this input.
    // Use this test's own name for the index (the previous value was copy-pasted from an
    // unrelated test) and ToLowerInvariant for a culture-independent identifier (CA1304).
    var indexName = nameof(EnsureTextSplitsOnBigSectionAsync).ToLowerInvariant();
    var openAIEndpoint = "https://fake.openai.azure.com";
    var embeddingDeployment = "embeddings";
    var azureSearchEndpoint = "https://fake-search.search.azure.com";
    var blobEndpoint = "https://fake-storage.azure.com/";
    var blobContainer = "test";

    var azureCredential = new DefaultAzureCredential();
    var openAIClient = new OpenAIClient(new Uri(openAIEndpoint), azureCredential);
    var searchClient = new SearchClient(new Uri(azureSearchEndpoint), indexName, azureCredential);
    var searchIndexClient = new SearchIndexClient(new Uri(azureSearchEndpoint), azureCredential);
    var documentAnalysisClient = new DocumentAnalysisClient(new Uri(azureSearchEndpoint), azureCredential);
    var blobServiceClient = new BlobServiceClient(new Uri(blobEndpoint), azureCredential);

    var service = new AzureSearchEmbedService(
        openAIClient: openAIClient,
        embeddingModelName: embeddingDeployment,
        searchClient: searchClient,
        searchIndexName: indexName,
        searchIndexClient: searchIndexClient,
        documentAnalysisClient: documentAnalysisClient,
        corpusContainerClient: blobServiceClient.GetBlobContainerClient(blobContainer),
        computerVisionService: null,
        includeImageEmbeddingsField: false,
        logger: null,
        maxTokensPerSection: 500);

    List<Section> sections = [];

    // 1,000 spaces sits exactly at the character-length section cap while tokenizing to
    // far fewer than 500 tokens, so it must survive as a single section.
    string testContent = new(' ', 1_000);
    IReadOnlyList<PageDetail> pageMap =
    [
        new(0, 0, testContent)
    ];

    // Act
    await foreach (var section in service.CreateSectionsAsync(pageMap, "test-blob"))
    {
        sections.Add(section);
    }

    // Assert
    sections.Should().HaveCount(1);
    sections[0].Content.Should().Be(testContent);
}
}
Loading