Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add streaming ability in chat page using default SDK #387

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions app/SharedWebComponents/Pages/Chat.razor.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.

namespace SharedWebComponents.Pages;
using System.Text;

public sealed partial class Chat
{
Expand Down Expand Up @@ -42,17 +43,39 @@ private async Task OnAskClickedAsync()
try
{
var history = _questionAndAnswerMap
.Where(x => x.Value?.Choices is { Length: > 0})
.SelectMany(x => new ChatMessage[] { new ChatMessage("user", x.Key.Question), new ChatMessage("assistant", x.Value!.Choices[0].Message.Content) })
.Where(x => x.Value?.Choices is { Length: > 0 })
.SelectMany(x => new ChatMessage[] {
new ChatMessage("user", x.Key.Question),
new ChatMessage("assistant", x.Value!.Choices[0].Message.Content)
})
.ToList();

history.Add(new ChatMessage("user", _userQuestion));

var request = new ChatRequest([.. history], Settings.Overrides);
var result = await ApiClient.ChatConversationAsync(request);

_questionAndAnswerMap[_currentQuestion] = result.Response;
if (result.IsSuccessful)
try
{
var responseStream = await ApiClient.PostStreamingRequestAsync(request, "api/chat/stream");

await foreach (var response in responseStream)
{
_questionAndAnswerMap[_currentQuestion] = new ChatAppResponseOrError(
response.Choices,
null);

StateHasChanged();
await Task.Delay(1);
}
}
catch (Exception ex)
{
_questionAndAnswerMap[_currentQuestion] = new ChatAppResponseOrError(
Array.Empty<ResponseChoice>(),
ex.Message);
}

if (_questionAndAnswerMap[_currentQuestion]?.Error is null)
{
_userQuestion = "";
_currentQuestion = default;
Expand Down
43 changes: 14 additions & 29 deletions app/SharedWebComponents/Services/ApiClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -90,47 +90,32 @@ public async IAsyncEnumerable<DocumentResponse> GetDocumentsAsync(
}
}

public Task<AnswerResult<ChatRequest>> ChatConversationAsync(ChatRequest request) => PostRequestAsync(request, "api/chat");

private async Task<AnswerResult<TRequest>> PostRequestAsync<TRequest>(
public async Task<IAsyncEnumerable<ChatAppResponse>> PostStreamingRequestAsync<TRequest>(
TRequest request, string apiRoute) where TRequest : ApproachRequest
{
var result = new AnswerResult<TRequest>(
IsSuccessful: false,
Response: null,
Approach: request.Approach,
Request: request);

var json = JsonSerializer.Serialize(
request,
SerializerOptions.Default);

using var body = new StringContent(
using var content = new StringContent(
json, Encoding.UTF8, "application/json");

var response = await httpClient.PostAsync(apiRoute, body);
// Use both HttpCompletionOption and CancellationToken
var response = await httpClient.PostAsync(
apiRoute,
content,
CancellationToken.None);

if (response.IsSuccessStatusCode)
{
var answer = await response.Content.ReadFromJsonAsync<ChatAppResponseOrError>();
return result with
{
IsSuccessful = answer is not null,
Response = answer,
};
var stream = await response.Content.ReadAsStreamAsync();
var nullableResponses = JsonSerializer.DeserializeAsyncEnumerable<ChatAppResponse>(
stream,
new JsonSerializerOptions { PropertyNameCaseInsensitive = true });

return nullableResponses.Where(r => r != null)!;
}
else
{
var errorTitle = $"HTTP {(int)response.StatusCode} : {response.ReasonPhrase ?? "☹️ Unknown error..."}";
var answer = new ChatAppResponseOrError(
Array.Empty<ResponseChoice>(),
errorTitle);

return result with
{
IsSuccessful = false,
Response = answer
};
}
throw new HttpRequestException($"HTTP {(int)response.StatusCode} : {response.ReasonPhrase ?? "Unknown error"}");
}
}
21 changes: 12 additions & 9 deletions app/backend/Extensions/WebApplicationExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ internal static WebApplication MapApi(this WebApplication app)
api.MapPost("openai/chat", OnPostChatPromptAsync);

// Long-form chat w/ contextual history endpoint
api.MapPost("chat", OnPostChatAsync);
api.MapPost("chat/stream", OnPostChatStreamingAsync);

// Upload a document
api.MapPost("documents", OnPostDocumentAsync);
Expand Down Expand Up @@ -70,20 +70,23 @@ You will always reply with a Markdown formatted response.
}
}

private static async Task<IResult> OnPostChatAsync(
private static async IAsyncEnumerable<ChatAppResponse> OnPostChatStreamingAsync(
ChatRequest request,
ReadRetrieveReadChatService chatService,
CancellationToken cancellationToken)
[EnumeratorCancellation] CancellationToken cancellationToken)
{
if (request is { History.Length: > 0 })
if (request is not { History.Length: > 0 })
{
var response = await chatService.ReplyAsync(
request.History, request.Overrides, cancellationToken);

return TypedResults.Ok(response);
yield break;
}

return Results.BadRequest();
await foreach (var response in chatService.ReplyStreamingAsync(
request.History,
request.Overrides,
cancellationToken))
{
yield return response;
}
}

private static async Task<IResult> OnPostDocumentAsync(
Expand Down
141 changes: 91 additions & 50 deletions app/backend/Services/ReadRetrieveReadChatService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel.Connectors.OpenAI;
using Microsoft.SemanticKernel.Embeddings;
using System.Text;

namespace MinimalApi.Services;
#pragma warning disable SKEXP0011 // Mark members as static
Expand Down Expand Up @@ -56,10 +57,10 @@ public ReadRetrieveReadChatService(
_tokenCredential = tokenCredential;
}

public async Task<ChatAppResponse> ReplyAsync(
public async IAsyncEnumerable<ChatAppResponse> ReplyStreamingAsync(
ChatMessage[] history,
RequestOverrides? overrides,
CancellationToken cancellationToken = default)
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var top = overrides?.Top ?? 3;
var useSemanticCaptions = overrides?.SemanticCaptions ?? false;
Expand All @@ -71,9 +72,8 @@ public async Task<ChatAppResponse> ReplyAsync(
float[]? embeddings = null;
var question = history.LastOrDefault(m => m.IsUser)?.Content is { } userQuestion
? userQuestion
: throw new InvalidOperationException("Use question is null");
: throw new InvalidOperationException("User question is null");

string[]? followUpQuestionList = null;
if (overrides?.RetrievalMode != RetrievalMode.Text && embedding is not null)
{
embeddings = (await embedding.GenerateEmbeddingAsync(question, cancellationToken: cancellationToken)).ToArray();
Expand All @@ -92,11 +92,20 @@ standard plan AND dental AND employee benefit.
");

getQueryChat.AddUserMessage(question);
var result = await chat.GetChatMessageContentAsync(
var queryBuilder = new StringBuilder();

await foreach (var content in chat.GetStreamingChatMessageContentsAsync(
getQueryChat,
cancellationToken: cancellationToken);
kernel: _kernel,
cancellationToken: cancellationToken))
{
if (content.Content is { Length: > 0 })
{
queryBuilder.Append(content.Content);
}
}

query = result.Content ?? throw new InvalidOperationException("Failed to get search query");
query = queryBuilder.ToString() ?? throw new InvalidOperationException("Failed to get search query");
}

// step 2
Expand All @@ -110,7 +119,7 @@ standard plan AND dental AND employee benefit.
}
else
{
documentContents = string.Join("\r", documentContentList.Select(x =>$"{x.Title}:{x.Content}"));
documentContents = string.Join("\r", documentContentList.Select(x => $"{x.Title}:{x.Content}"));
}

// step 2.5
Expand Down Expand Up @@ -140,7 +149,7 @@ standard plan AND dental AND employee benefit.
}
}


if (images != null)
{
var prompt = @$"## Source ##
Expand Down Expand Up @@ -185,63 +194,95 @@ You answer needs to be a json object with the following format.
StopSequences = [],
};

var streamingResponse = new StringBuilder();
var documentContext = new ResponseContext(
DataPointsContent: documentContentList.Select(x => new SupportingContentRecord(x.Title, x.Content)).ToArray(),
DataPointsImages: images?.Select(x => new SupportingImageRecord(x.Title, x.Url)).ToArray(),
FollowupQuestions: Array.Empty<string>(), // Will be populated after full response
Thoughts: Array.Empty<Thoughts>()); // Will be populated after full response

// get answer
var answer = await chat.GetChatMessageContentAsync(
answerChat,
promptExecutingSetting,
cancellationToken: cancellationToken);
var answerJson = answer.Content ?? throw new InvalidOperationException("Failed to get search query");
var answerObject = JsonSerializer.Deserialize<JsonElement>(answerJson);
var ans = answerObject.GetProperty("answer").GetString() ?? throw new InvalidOperationException("Failed to get answer");
var thoughts = answerObject.GetProperty("thoughts").GetString() ?? throw new InvalidOperationException("Failed to get thoughts");
await foreach (var content in chat.GetStreamingChatMessageContentsAsync(
answerChat,
executionSettings: promptExecutingSetting,
kernel: _kernel,
cancellationToken: cancellationToken))
{
if (content.Content is { Length: > 0 })
{
streamingResponse.Append(content.Content);
var responseMessage = new ResponseMessage("assistant", streamingResponse.ToString());
var choice = new ResponseChoice(
Index: 0,
Message: responseMessage,
Context: documentContext,
CitationBaseUrl: _configuration.ToCitationBaseUrl());


yield return new ChatAppResponse(new[] { choice });
}
}

// After streaming completes, parse the final answer
var answerJson = streamingResponse.ToString();
var finalAnswerObject = JsonSerializer.Deserialize<JsonElement>(answerJson);
var ans = finalAnswerObject.GetProperty("answer").GetString() ?? throw new InvalidOperationException("Failed to get answer");
var finalThoughts = finalAnswerObject.GetProperty("thoughts").GetString() ?? throw new InvalidOperationException("Failed to get thoughts");

// Create response context that will be used throughout
var responseContext = new ResponseContext(
DataPointsContent: documentContentList.Select(x => new SupportingContentRecord(x.Title, x.Content)).ToArray(),
DataPointsImages: images?.Select(x => new SupportingImageRecord(x.Title, x.Url)).ToArray(),
FollowupQuestions: Array.Empty<string>(),
Thoughts: new[] { new Thoughts("Thoughts", finalThoughts) });

// step 4
// add follow up questions if requested
if (overrides?.SuggestFollowupQuestions is true)
{
var followUpQuestionChat = new ChatHistory(@"You are a helpful AI assistant");
followUpQuestionChat.AddUserMessage($@"Generate three follow-up question based on the answer you just generated.
followUpQuestionChat.AddUserMessage($@"Generate three follow-up questions based on the answer you just generated.
# Answer
{ans}

# Format of the response
Return the follow-up question as a json string list. Don't put your answer between ```json and ```, return the json string directly.
e.g.
[
""What is the deductible?"",
""What is the co-pay?"",
""What is the out-of-pocket maximum?""
]");
Generate three questions, one per line. Do not include any JSON formatting or other text.
For example:
What is the deductible?
What is the co-pay?
What is the out-of-pocket maximum?");

var followUpQuestions = await chat.GetChatMessageContentAsync(
var followUpQuestions = new List<string>();
var followUpBuilder = new StringBuilder();
await foreach (var content in chat.GetStreamingChatMessageContentsAsync(
followUpQuestionChat,
promptExecutingSetting,
cancellationToken: cancellationToken);

var followUpQuestionsJson = followUpQuestions.Content ?? throw new InvalidOperationException("Failed to get search query");
var followUpQuestionsObject = JsonSerializer.Deserialize<JsonElement>(followUpQuestionsJson);
var followUpQuestionsList = followUpQuestionsObject.EnumerateArray().Select(x => x.GetString()!).ToList();
foreach (var followUpQuestion in followUpQuestionsList)
executionSettings: promptExecutingSetting,
kernel: _kernel,
cancellationToken: cancellationToken))
{
ans += $" <<{followUpQuestion}>> ";
}
if (content.Content is { Length: > 0 })
{
followUpBuilder.Append(content.Content);
var questions = followUpBuilder.ToString().Split('\n', StringSplitOptions.RemoveEmptyEntries);

var answerWithQuestions = ans;
foreach (var followUpQuestion in questions)
{
answerWithQuestions += $" <<{followUpQuestion.Trim()}>> ";
}

followUpQuestionList = followUpQuestionsList.ToArray();
}
var responseMessage = new ResponseMessage("assistant", answerWithQuestions);
var updatedContext = responseContext with { FollowupQuestions = questions };

var responseMessage = new ResponseMessage("assistant", ans);
var responseContext = new ResponseContext(
DataPointsContent: documentContentList.Select(x => new SupportingContentRecord(x.Title, x.Content)).ToArray(),
DataPointsImages: images?.Select(x => new SupportingImageRecord(x.Title, x.Url)).ToArray(),
FollowupQuestions: followUpQuestionList ?? Array.Empty<string>(),
Thoughts: new[] { new Thoughts("Thoughts", thoughts) });
var choice = new ResponseChoice(
Index: 0,
Message: responseMessage,
Context: updatedContext,
CitationBaseUrl: _configuration.ToCitationBaseUrl());

var choice = new ResponseChoice(
Index: 0,
Message: responseMessage,
Context: responseContext,
CitationBaseUrl: _configuration.ToCitationBaseUrl());

return new ChatAppResponse(new[] { choice });
yield return new ChatAppResponse(new[] { choice });
}
}
}
}
}
Loading