Ensure chunked PDF documents are never bigger than 500 tokens, support CJK and fix bug with tiny documents #303
base: main
@@ -13,6 +13,7 @@
 using Azure.Storage.Blobs;
 using Azure.Storage.Blobs.Models;
 using Microsoft.Extensions.Logging;
+using Microsoft.ML.Tokenizers;
 using Shared.Models;

 public sealed partial class AzureSearchEmbedService(
@@ -25,11 +26,16 @@ public sealed partial class AzureSearchEmbedService(
     BlobContainerClient corpusContainerClient,
     IComputerVisionService? computerVisionService = null,
     bool includeImageEmbeddingsField = false,
-    ILogger<AzureSearchEmbedService>? logger = null) : IEmbedService
+    ILogger<AzureSearchEmbedService>? logger = null,
+    string modelEncodingName = "gpt-4",
+    int maxTokensPerSection = 500) : IEmbedService
 {
     [GeneratedRegex("[^0-9a-zA-Z_-]")]
     private static partial Regex MatchInSetRegex();

+    private static readonly char[] s_sentenceEndings = ['.', '。', '.', '!', '?', '‼', '⁇', '⁈', '⁉'];
+    private const int DefaultOverlapPercent = 10;
+
     public async Task<bool> EmbedPDFBlobAsync(Stream pdfBlobStream, string blobName)
     {
         try
@@ -48,7 +54,11 @@ public async Task<bool> EmbedPDFBlobAsync(Stream pdfBlobStream, string blobName)
             await UploadCorpusAsync(corpusName, page.Text);
         }

-        var sections = CreateSections(pageMap, blobName);
+        List<Section> sections = [];
+        await foreach(var section in CreateSectionsAsync(pageMap, blobName))
+        {
+            sections.Add(section);
+        }

         var infoLoggingEnabled = logger?.IsEnabled(LogLevel.Information);
         if (infoLoggingEnabled is true)
@@ -109,7 +119,7 @@ public async Task<bool> EmbedImageBlobAsync(
         return true;
     }

-    public async Task CreateSearchIndexAsync(string searchIndexName, CancellationToken ct = default)
+    public async Task CreateSearchIndexAsync(string searchIndexName, CancellationToken ct = default)
     {
         string vectorSearchConfigName = "my-vector-config";
         string vectorSearchProfile = "my-vector-profile";
@@ -315,18 +325,97 @@ private async Task UploadCorpusAsync(string corpusBlobName, string text)
         });
     }

-    public IEnumerable<Section> CreateSections(
+    public async IAsyncEnumerable<Section> SplitSectionByTokenLengthAsync(string Id,
+        string Content,
+        string SourcePage,
+        string SourceFile,
+        string? Category = null,
+        Tokenizer? tokenizer = null)
+    {
+        tokenizer ??= await Tiktoken.CreateByModelNameAsync(modelEncodingName);
+        if (tokenizer.CountTokens(Content) <= maxTokensPerSection)
+        {
+            yield return new Section(
+                Id: Id,
+                Content: Content,
+                SourcePage: SourcePage,
+                SourceFile: SourceFile,
+                Category: Category);
+        } else
+        {
+            int start = Content.Length / 2;
+            int pos = 0;
+            int boundary = Content.Length / 3;
+            int splitPosition = -1;
+
+            while (start - pos > boundary)
+            {
+                if (s_sentenceEndings.Contains(Content[start - pos]))
+                {
+                    splitPosition = start - pos;
+                    break;
+                }
+                else if (s_sentenceEndings.Contains(Content[start + pos]))
+                {
+                    splitPosition = start + pos;
+                    break;
+                }
+                else
+                {
+                    pos += 1;
+                }
+            }
+
+            string firstHalf, secondHalf;
+
+            if (splitPosition > 0)
+            {
+                firstHalf = Content[..(splitPosition + 1)];
+                secondHalf = Content[(splitPosition + 1)..];
+            }
+            else
+            {
+                // Split page in half and call function again
+                // Overlap first and second halves by DEFAULT_OVERLAP_PERCENT%
+                int middle = Content.Length / 2;
+                int overlapChars = (int)(Content.Length * (DefaultOverlapPercent / 100.0));
+                firstHalf = Content[..(middle + overlapChars)];
+                secondHalf = Content[(middle - overlapChars)..];
+            }
+
+            await foreach(var section in SplitSectionByTokenLengthAsync(Id, firstHalf, SourcePage, SourceFile, Category, tokenizer))
+            {
+                yield return section;
+            }
+            await foreach (var section in SplitSectionByTokenLengthAsync(Id, secondHalf, SourcePage, SourceFile, Category, tokenizer))
+            {
+                yield return section;
+            }
+        }
+
+    }
+
+    public async IAsyncEnumerable<Section> CreateSectionsAsync(
         IReadOnlyList<PageDetail> pageMap, string blobName)
     {
         const int MaxSectionLength = 1_000;
         const int SentenceSearchLimit = 100;
         const int SectionOverlap = 100;

-        var sentenceEndings = new[] { '.', '!', '?' };
-        var wordBreaks = new[] { ',', ';', ':', ' ', '(', ')', '[', ']', '{', '}', '\t', '\n' };
+        var wordBreaks = new[] { ',', '、', ';', ':', ' ', '(', ')', '[', ']', '{', '}', '\t', '\n' };
         var allText = string.Concat(pageMap.Select(p => p.Text));
         var length = allText.Length;
        var start = 0;

+        if (length <= MaxSectionLength)
+        {
+            await foreach (var section in SplitSectionByTokenLengthAsync(
+                Id: MatchInSetRegex().Replace($"{blobName}-{start}", "_").TrimStart('_'),
+                Content: allText,
+                SourcePage: BlobNameFromFilePage(blobName, FindPage(pageMap, start)),
+                SourceFile: blobName)) { yield return section; }
+        }

         var end = length;

         logger?.LogInformation("Splitting '{BlobName}' into sections", blobName);
@@ -343,7 +432,7 @@ public IEnumerable<Section> CreateSections(
             else
             {
                 // Try to find the end of the sentence
-                while (end < length && (end - start - MaxSectionLength) < SentenceSearchLimit && !sentenceEndings.Contains(allText[end]))
+                while (end < length && (end - start - MaxSectionLength) < SentenceSearchLimit && !s_sentenceEndings.Contains(allText[end]))
                 {
                     if (wordBreaks.Contains(allText[end]))
                     {
@@ -352,7 +441,7 @@ public IEnumerable<Section> CreateSections(
                     end++;
                 }

-                if (end < length && !sentenceEndings.Contains(allText[end]) && lastWord > 0)
+                if (end < length && !s_sentenceEndings.Contains(allText[end]) && lastWord > 0)
                 {
                     end = lastWord; // Fall back to at least keeping a whole word
                 }
@@ -366,7 +455,7 @@ public IEnumerable<Section> CreateSections(
                 // Try to find the start of the sentence or at least a whole word boundary
                 lastWord = -1;
                 while (start > 0 && start > end - MaxSectionLength -
-                    (2 * SentenceSearchLimit) && !sentenceEndings.Contains(allText[start]))
+                    (2 * SentenceSearchLimit) && !s_sentenceEndings.Contains(allText[start]))
                 {
                     if (wordBreaks.Contains(allText[start]))
                     {
@@ -375,7 +464,7 @@ public IEnumerable<Section> CreateSections(
                     start--;
                 }

-                if (!sentenceEndings.Contains(allText[start]) && lastWord > 0)
+                if (!s_sentenceEndings.Contains(allText[start]) && lastWord > 0)
                 {
                     start = lastWord;
                 }
@@ -386,11 +475,11 @@ public IEnumerable<Section> CreateSections(

             var sectionText = allText[start..end];

-            yield return new Section(
+            await foreach (var section in SplitSectionByTokenLengthAsync(
                 Id: MatchInSetRegex().Replace($"{blobName}-{start}", "_").TrimStart('_'),
                 Content: sectionText,
                 SourcePage: BlobNameFromFilePage(blobName, FindPage(pageMap, start)),
-                SourceFile: blobName);
+                SourceFile: blobName)) { yield return section; }

             var lastTableStart = sectionText.LastIndexOf("<table", StringComparison.Ordinal);
             if (lastTableStart > 2 * SentenceSearchLimit && lastTableStart > sectionText.LastIndexOf("</table", StringComparison.Ordinal))
@@ -419,11 +508,11 @@ table at page {Offset} offset {Start} table start {LastTableStart}

         if (start + SectionOverlap < end)
         {
-            yield return new Section(
+            await foreach (var section in SplitSectionByTokenLengthAsync(
                 Id: MatchInSetRegex().Replace($"{blobName}-{start}", "_").TrimStart('_'),
                 Content: allText[start..end],
                 SourcePage: BlobNameFromFilePage(blobName, FindPage(pageMap, start)),
-                SourceFile: blobName);
+                SourceFile: blobName)) { yield return section; }
         }
     }
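For orientation, the new public splitter can be exercised on its own. The snippet below is only an illustrative sketch, not part of the change: it assumes an `AzureSearchEmbedService` instance (called `service` here) constructed the same way as in the tests further down, and the sample text and identifiers are made up.

```csharp
// Sketch: run an over-long page through the new splitter. Each emitted Section has
// already been checked by SplitSectionByTokenLengthAsync to fit within
// maxTokensPerSection (500 by default); here we just print its id and character length.
var longText = string.Join(" ",
    Enumerable.Repeat("This is a fairly long sentence that repeats many times in the sample page.", 200));

await foreach (var section in service.SplitSectionByTokenLengthAsync(
    Id: "sample-0",
    Content: longText,
    SourcePage: "sample-1.pdf",
    SourceFile: "sample.pdf"))
{
    Console.WriteLine($"{section.Id}: {section.Content.Length} chars");
}
```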
@@ -16,7 +16,10 @@
 using Azure.Storage.Blobs;
 using FluentAssertions;
 using Microsoft.Extensions.Logging;
+using MudBlazor.Services;

Reviewer: Why is this needed?

 using NSubstitute;
 using Shared.Models;
 using Xunit;

 namespace MinimalApi.Tests;
 public class AzureSearchEmbedServiceTest
@@ -312,4 +315,89 @@ public async Task EmbedImageBlobTestAsync()
             await blobServiceClient.DeleteBlobContainerAsync(blobContainer);
         }
     }
+
+    [Fact]
+    public async Task EnsureTextSplitsOnTinySectionAsync()

Reviewer: I'd expect a lot more test cases. What about some content with the various word splits or sentence delimiters? Also, consider using a …

(One possible data-driven shape is sketched after this file's diff.)
+    {
+        var indexName = nameof(EnsureSearchIndexWithoutImageEmbeddingsAsync).ToLower();
+        var openAIEndpoint = "https://fake.openai.azure.com";
+        var embeddingDeployment = "gpt-4";
+        var azureSearchEndpoint = "https://fake-search.search.azure.com";
+        var blobEndpoint = "https://fake-storage.azure.com/";
+        var blobContainer = "test";
+
+        var azureCredential = new DefaultAzureCredential();
+        var openAIClient = new OpenAIClient(new Uri(openAIEndpoint), azureCredential);
+        var searchClient = new SearchClient(new Uri(azureSearchEndpoint), indexName, azureCredential);
+        var searchIndexClient = new SearchIndexClient(new Uri(azureSearchEndpoint), azureCredential);
+        var documentAnalysisClient = new DocumentAnalysisClient(new Uri(azureSearchEndpoint), azureCredential);
+        var blobServiceClient = new BlobServiceClient(new Uri(blobEndpoint), azureCredential);
+
+        var service = new AzureSearchEmbedService(
+            openAIClient: openAIClient,
+            embeddingModelName: embeddingDeployment,
+            searchClient: searchClient,
+            searchIndexName: indexName,
+            searchIndexClient: searchIndexClient,
+            documentAnalysisClient: documentAnalysisClient,
+            corpusContainerClient: blobServiceClient.GetBlobContainerClient(blobContainer),
+            computerVisionService: null,
+            includeImageEmbeddingsField: false,
+            logger: null,
+            maxTokensPerSection: 500);
+        List<Section> sections = [];
+        IReadOnlyList<PageDetail> pageMap =
+        [
+            new(0, 0, "this is a test")
+        ];
+        await foreach (var section in service.CreateSectionsAsync(pageMap, "test-blob"))
+        {
+            sections.Add(section);
+        }
+        sections.Should().HaveCount(1);
+        sections[0].Content.Should().Be("this is a test");
+    }
+
+    [Fact]
+    public async Task EnsureTextSplitsOnBigSectionAsync()
+    {
+        var indexName = nameof(EnsureSearchIndexWithoutImageEmbeddingsAsync).ToLower();
+        var openAIEndpoint = "https://fake.openai.azure.com";
+        var embeddingDeployment = "embeddings";
+        var azureSearchEndpoint = "https://fake-search.search.azure.com";
+        var blobEndpoint = "https://fake-storage.azure.com/";
+        var blobContainer = "test";
+
+        var azureCredential = new DefaultAzureCredential();
+        var openAIClient = new OpenAIClient(new Uri(openAIEndpoint), azureCredential);
+        var searchClient = new SearchClient(new Uri(azureSearchEndpoint), indexName, azureCredential);
+        var searchIndexClient = new SearchIndexClient(new Uri(azureSearchEndpoint), azureCredential);
+        var documentAnalysisClient = new DocumentAnalysisClient(new Uri(azureSearchEndpoint), azureCredential);
+        var blobServiceClient = new BlobServiceClient(new Uri(blobEndpoint), azureCredential);
+
+        var service = new AzureSearchEmbedService(
+            openAIClient: openAIClient,
+            embeddingModelName: embeddingDeployment,
+            searchClient: searchClient,
+            searchIndexName: indexName,
+            searchIndexClient: searchIndexClient,
+            documentAnalysisClient: documentAnalysisClient,
+            corpusContainerClient: blobServiceClient.GetBlobContainerClient(blobContainer),
+            computerVisionService: null,
+            includeImageEmbeddingsField: false,
+            logger: null,
+            maxTokensPerSection: 500);
+        List<Section> sections = [];
+        string testContent = "".PadRight(1000, ' ');
+        IReadOnlyList<PageDetail> pageMap =
+        [
+            new(0, 0, testContent)
+        ];
+        await foreach (var section in service.CreateSectionsAsync(pageMap, "test-blob"))
+        {
+            sections.Add(section);
+        }
+        sections.Should().HaveCount(1);
+        sections[0].Content.Should().Be(testContent);
+    }
 }
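On the reviewer's point above about broader coverage: one possible shape is a parameterized xUnit test that feeds the splitter inputs with different word breaks and sentence delimiters, including CJK. This is a sketch only; `CreateServiceUnderTest` is a hypothetical helper standing in for the service construction repeated in the two tests above, and the inputs and assertions are illustrative rather than prescriptive.

```csharp
// Sketch of a data-driven variant of the tests above. CreateServiceUnderTest is a
// hypothetical factory wrapping the AzureSearchEmbedService construction shown in
// EnsureTextSplitsOnTinySectionAsync.
[Theory]
[InlineData("this is a test")]                                   // tiny content, no split expected
[InlineData("First sentence. Second sentence! Third sentence?")] // Western sentence endings
[InlineData("これは一つ目の文です。これは二つ目の文です。")]       // CJK sentence endings
[InlineData("comma, separated; clauses: with (word) breaks")]    // word-break characters
public async Task EnsureSectionsAreProducedForVariousDelimitersAsync(string testContent)
{
    var service = CreateServiceUnderTest(maxTokensPerSection: 500);

    IReadOnlyList<PageDetail> pageMap = [new(0, 0, testContent)];

    List<Section> sections = [];
    await foreach (var section in service.CreateSectionsAsync(pageMap, "test-blob"))
    {
        sections.Add(section);
    }

    sections.Should().NotBeEmpty();
    sections.Should().OnlyContain(s => !string.IsNullOrWhiteSpace(s.Content));
}
```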
Reviewer: Lift this to a class-scoped variable to avoid reallocating the same thing multiple times with each call to CreateSectionsAsync.
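A sketch of what that might look like, assuming the comment refers to the `wordBreaks` array allocated inside `CreateSectionsAsync`: lifting it to a static field next to `s_sentenceEndings` means it is allocated once per process instead of on every call.

```csharp
// Hypothetical hoisted field in AzureSearchEmbedService; CreateSectionsAsync would
// then read s_wordBreaks instead of declaring a local array on each call.
private static readonly char[] s_wordBreaks =
    [',', '、', ';', ':', ' ', '(', ')', '[', ']', '{', '}', '\t', '\n'];
```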