Skip to content

Commit

Permalink
Merge pull request #47 from milderhc/add-sk-chat
Browse files Browse the repository at this point in the history
Add SK chat support
  • Loading branch information
dantelmomsft authored Nov 8, 2023
2 parents 7537fdd + 659b508 commit ae612e9
Show file tree
Hide file tree
Showing 15 changed files with 716 additions and 129 deletions.
16 changes: 9 additions & 7 deletions README.md

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
import org.slf4j.LoggerFactory;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Bean;
import org.springframework.web.servlet.config.annotation.CorsRegistry;
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;

@SpringBootApplication
public class Application {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
// Copyright (c) Microsoft. All rights reserved.
package com.microsoft.openai.samples.rag.approaches;

import com.microsoft.openai.samples.rag.ask.approaches.PlainJavaAskApproach;
import com.microsoft.openai.samples.rag.ask.approaches.semantickernel.JavaSemanticKernelChainsApproach;
import com.microsoft.openai.samples.rag.ask.approaches.semantickernel.JavaSemanticKernelPlannerApproach;
import com.microsoft.openai.samples.rag.ask.approaches.semantickernel.JavaSemanticKernelWithMemoryApproach;
import com.microsoft.openai.samples.rag.chat.approaches.PlainJavaChatApproach;
import com.microsoft.openai.samples.rag.chat.approaches.semantickernel.JavaSemanticKernelChainsChatApproach;
import com.microsoft.openai.samples.rag.chat.approaches.semantickernel.JavaSemanticKernelWithMemoryChatApproach;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.stereotype.Component;

@Component
public class RAGApproachFactorySpringBootImpl
implements RAGApproachFactory, ApplicationContextAware {
public class RAGApproachFactorySpringBootImpl implements RAGApproachFactory, ApplicationContextAware {

private static final String JAVA_OPENAI_SDK = "jos";
private static final String JAVA_SEMANTIC_KERNEL = "jsk";
Expand All @@ -29,31 +29,34 @@ public class RAGApproachFactorySpringBootImpl
@Override
public RAGApproach createApproach(String approachName, RAGType ragType, RAGOptions ragOptions) {

if (ragType.equals(RAGType.CHAT) && JAVA_OPENAI_SDK.equals(approachName)) {
return applicationContext.getBean(PlainJavaChatApproach.class);

if (ragType.equals(RAGType.CHAT)) {
if (JAVA_OPENAI_SDK.equals(approachName)) {
return applicationContext.getBean(PlainJavaChatApproach.class);
} else if (JAVA_SEMANTIC_KERNEL.equals(approachName)) {
return applicationContext.getBean(JavaSemanticKernelWithMemoryChatApproach.class);
} else if (
JAVA_SEMANTIC_KERNEL_PLANNER.equals(approachName) &&
ragOptions != null &&
ragOptions.getSemantickKernelMode() != null &&
ragOptions.getSemantickKernelMode() == SemanticKernelMode.chains) {
return applicationContext.getBean(JavaSemanticKernelChainsChatApproach.class);
}
} else if (ragType.equals(RAGType.ASK)) {
if (JAVA_OPENAI_SDK.equals(approachName))
return applicationContext.getBean(PlainJavaAskApproach.class);
else if (JAVA_SEMANTIC_KERNEL.equals(approachName))
return applicationContext.getBean(JavaSemanticKernelWithMemoryApproach.class);
else if (JAVA_SEMANTIC_KERNEL_PLANNER.equals(approachName)
&& ragOptions.getSemantickKernelMode() != null
&& ragOptions.getSemantickKernelMode() == SemanticKernelMode.planner)
else if (JAVA_SEMANTIC_KERNEL_PLANNER.equals(approachName) && ragOptions.getSemantickKernelMode() != null && ragOptions.getSemantickKernelMode() == SemanticKernelMode.planner)
return applicationContext.getBean(JavaSemanticKernelPlannerApproach.class);
else if (JAVA_SEMANTIC_KERNEL_PLANNER.equals(approachName)
&& ragOptions != null
&& ragOptions.getSemantickKernelMode() != null
&& ragOptions.getSemantickKernelMode() == SemanticKernelMode.chains)
else if (JAVA_SEMANTIC_KERNEL_PLANNER.equals(approachName) && ragOptions != null && ragOptions.getSemantickKernelMode() != null && ragOptions.getSemantickKernelMode() == SemanticKernelMode.chains)
return applicationContext.getBean(JavaSemanticKernelChainsApproach.class);
}
// if this point is reached then the combination of approach and rag type is not supported
throw new IllegalArgumentException(
"Invalid combination for approach[%s] and rag type[%s]: "
.formatted(approachName, ragType));
//if this point is reached then the combination of approach and rag type is not supported
throw new IllegalArgumentException("Invalid combination for approach[%s] and rag type[%s]: ".formatted(approachName, ragType));
}

public void setApplicationContext(ApplicationContext applicationContext) {
this.applicationContext = applicationContext;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import com.microsoft.openai.samples.rag.approaches.RAGApproach;
import com.microsoft.openai.samples.rag.approaches.RAGOptions;
import com.microsoft.openai.samples.rag.approaches.RAGResponse;
import com.microsoft.openai.samples.rag.retrieval.semantickernel.CognitiveSearchPlugin;
import com.microsoft.openai.samples.rag.proxy.CognitiveSearchProxy;
import com.microsoft.openai.samples.rag.proxy.OpenAIProxy;
import com.microsoft.semantickernel.Kernel;
Expand All @@ -25,7 +26,7 @@
/**
* Use Java Semantic Kernel framework with semantic and native functions chaining. It uses an
* imperative style for AI orchestration through semantic kernel functions chaining.
* InformationFinder.Search native function and RAG.AnswerQuestion semantic function are called
* InformationFinder.SearchFromQuestion native function and RAG.AnswerQuestion semantic function are called
* sequentially. Several cognitive search retrieval options are available: Text, Vector, Hybrid.
*/
@Component
Expand Down Expand Up @@ -74,7 +75,7 @@ public RAGResponse run(String question, RAGOptions options) {
question,
semanticKernel
.getSkill("InformationFinder")
.getFunction("Search", null))
.getFunction("SearchFromQuestion", null))
.block();

var sources = formSourcesList(searchContext.getResult());
Expand Down Expand Up @@ -135,9 +136,9 @@ private List<ContentSource> formSourcesList(String result) {

/**
* Build semantic kernel context with AnswerQuestion semantic function and
* InformationFinder.Search native function. AnswerQuestion is imported from
* src/main/resources/semantickernel/Plugins. InformationFinder.Search is implemented in a
* traditional Java class method: CognitiveSearchPlugin.search
* InformationFinder.SearchFromQuestion native function. AnswerQuestion is imported from
* src/main/resources/semantickernel/Plugins. InformationFinder.SearchFromQuestion is implemented in a
* traditional Java class method: CognitiveSearchPlugin.searchFromConversation
*
* @param options
* @return
Expand All @@ -155,7 +156,6 @@ private Kernel buildSemanticKernel(RAGOptions options) {
kernel.importSkill(
new CognitiveSearchPlugin(this.cognitiveSearchProxy, this.openAIProxy, options),
"InformationFinder");

kernel.importSkillFromResources("semantickernel/Plugins", "RAG", "AnswerQuestion", null);

return kernel;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.microsoft.openai.samples.rag.approaches.RAGApproach;
import com.microsoft.openai.samples.rag.approaches.RAGOptions;
import com.microsoft.openai.samples.rag.approaches.RAGResponse;
import com.microsoft.openai.samples.rag.retrieval.semantickernel.CognitiveSearchPlugin;
import com.microsoft.openai.samples.rag.proxy.CognitiveSearchProxy;
import com.microsoft.openai.samples.rag.proxy.OpenAIProxy;
import com.microsoft.semantickernel.Kernel;
Expand Down Expand Up @@ -97,8 +98,8 @@ public void runStreaming(
/**
* Build semantic kernel context with AnswerQuestion semantic function and
* InformationFinder.Search native function. AnswerQuestion is imported from
* src/main/resources/semantickernel/Plugins. InformationFinder.Search is implemented in a
* traditional Java class method: CognitiveSearchPlugin.search
* src/main/resources/semantickernel/Plugins. InformationFinder.SearchFromQuestion is implemented in a
* traditional Java class method: CognitiveSearchPlugin.searchFromQuestion
*
* @param options
* @return
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package com.microsoft.openai.samples.rag.chat.approaches.semantickernel;

import com.azure.ai.openai.OpenAIAsyncClient;
import com.microsoft.openai.samples.rag.approaches.ContentSource;
import com.microsoft.openai.samples.rag.approaches.RAGApproach;
import com.microsoft.openai.samples.rag.approaches.RAGOptions;
import com.microsoft.openai.samples.rag.approaches.RAGResponse;
import com.microsoft.openai.samples.rag.retrieval.semantickernel.CognitiveSearchPlugin;
import com.microsoft.openai.samples.rag.common.ChatGPTConversation;
import com.microsoft.openai.samples.rag.common.ChatGPTUtils;
import com.microsoft.openai.samples.rag.proxy.CognitiveSearchProxy;
import com.microsoft.openai.samples.rag.proxy.OpenAIProxy;
import com.microsoft.semantickernel.Kernel;
import com.microsoft.semantickernel.SKBuilders;
import com.microsoft.semantickernel.orchestration.ContextVariables;
import com.microsoft.semantickernel.orchestration.SKContext;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;

import java.io.OutputStream;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

/**
* Use Java Semantic Kernel framework with semantic and native functions chaining. It uses an
* imperative style for AI orchestration through semantic kernel functions chaining.
* InformationFinder.SearchFromConversation native function and RAG.AnswerConversation semantic function are called
* sequentially. Several cognitive search retrieval options are available: Text, Vector, Hybrid.
*/
@Component
public class JavaSemanticKernelChainsChatApproach implements RAGApproach<ChatGPTConversation, RAGResponse> {
private final CognitiveSearchProxy cognitiveSearchProxy;

private final OpenAIProxy openAIProxy;

private final OpenAIAsyncClient openAIAsyncClient;

@Value("${openai.chatgpt.deployment}")
private String gptChatDeploymentModelId;

public JavaSemanticKernelChainsChatApproach(CognitiveSearchProxy cognitiveSearchProxy, OpenAIAsyncClient openAIAsyncClient, OpenAIProxy openAIProxy) {
this.cognitiveSearchProxy = cognitiveSearchProxy;
this.openAIAsyncClient = openAIAsyncClient;
this.openAIProxy = openAIProxy;
}

/**
* @param questionOrConversation
* @param options
* @return
*/
@Override
public RAGResponse run(ChatGPTConversation questionOrConversation, RAGOptions options) {
String question = ChatGPTUtils.getLastUserQuestion(questionOrConversation.getMessages());
String conversation = ChatGPTUtils.formatAsChatML(questionOrConversation.toOpenAIChatMessages());

Kernel semanticKernel = buildSemanticKernel(options);

// STEP 1: Retrieve relevant documents using the current conversation. It reuses the
// CognitiveSearchRetriever appraoch through the CognitiveSearchPlugin native function.
SKContext searchContext =
semanticKernel.runAsync(
conversation,
semanticKernel.getSkill("InformationFinder").getFunction("SearchFromConversation", null)).block();

// STEP 2: Build a SK context with the sources retrieved from the memory store and conversation
ContextVariables variables = SKBuilders.variables()
.withVariable("sources", searchContext.getResult())
.withVariable("conversation", conversation)
.withVariable("suggestions", String.valueOf(options.isSuggestFollowupQuestions()))
.withVariable("input", question)
.build();

/**
* STEP 3: Get a reference of the semantic function [AnswerConversation] of the [RAG] plugin
* (a.k.a. skill) from the SK skills registry and provide it with the pre-built context.
* Triggering Open AI to get a reply.
*/
SKContext reply = semanticKernel.runAsync(variables,
semanticKernel.getSkill("RAG").getFunction("AnswerConversation", null)).block();

return new RAGResponse.Builder()
.prompt("Prompt is managed by Semantic Kernel")
.answer(reply.getResult())
.sources(formSourcesList(searchContext.getResult()))
.sourcesAsText(searchContext.getResult())
.question(question)
.build();
}

@Override
public void runStreaming(
ChatGPTConversation questionOrConversation,
RAGOptions options,
OutputStream outputStream) {
throw new IllegalStateException("Streaming not supported for this approach");
}

private List<ContentSource> formSourcesList(String result) {
if (result == null) {
return Collections.emptyList();
}
return Arrays.stream(result
.split("\n"))
.map(source -> {
String[] split = source.split(":", 2);
if (split.length >= 2) {
var sourceName = split[0].trim();
var sourceContent = split[1].trim();
return new ContentSource(sourceName, sourceContent);
} else {
return null;
}
})
.filter(Objects::nonNull)
.collect(Collectors.toList());
}

/**
* Build semantic kernel context with AnswerConversation semantic function and
* InformationFinder.SearchFromConversation native function. AnswerConversation is imported from
* src/main/resources/semantickernel/Plugins. InformationFinder.SearchFromConversation is implemented in a
* traditional Java class method: CognitiveSearchPlugin.searchFromConversation
*
* @param options
* @return
*/
private Kernel buildSemanticKernel(RAGOptions options) {
Kernel kernel = SKBuilders.kernel()
.withDefaultAIService(SKBuilders.chatCompletion()
.withModelId(gptChatDeploymentModelId)
.withOpenAIClient(this.openAIAsyncClient)
.build())
.build();

kernel.importSkill(
new CognitiveSearchPlugin(this.cognitiveSearchProxy, this.openAIProxy, options),
"InformationFinder");
kernel.importSkillFromResources("semantickernel/Plugins", "RAG", "AnswerConversation", null);

return kernel;
}

}
Loading

0 comments on commit ae612e9

Please sign in to comment.