-
Notifications
You must be signed in to change notification settings - Fork 721
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add zero shot NLU using LLMs * add docker-compose for ollama * switch between default NLU and LLM pipelines * add synonym replacer NLU component * refactor docker-compose * update docs
- Loading branch information
1 parent
c23c561
commit cf06697
Showing
30 changed files
with
663 additions
and
179 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,40 @@ | ||
from pydantic import BaseModel, Field, ConfigDict | ||
from typing import Dict, Any | ||
from pydantic import BaseModel, Field | ||
from typing import Optional | ||
from app.database import ObjectIdField | ||
from datetime import datetime | ||
|
||
|
||
class TraditionalNLUSettings(BaseModel):
    """Settings for traditional ML-based NLU pipeline"""

    # Threshold used for intent detection (default 0.75).
    # NOTE(review): presumably a minimum confidence score — confirm in the pipeline.
    intent_detection_threshold: float = 0.75
    # Threshold used for entity detection (default 0.65).
    # NOTE(review): presumably a minimum confidence score — confirm in the pipeline.
    entity_detection_threshold: float = 0.65
    # Toggle for spaCy usage in the pipeline; enabled by default.
    use_spacy: bool = True
|
||
|
||
class LLMSettings(BaseModel):
    """Settings for LLM-based NLU pipeline"""

    # Base URL of the LLM HTTP API; defaults to a server on localhost:11434.
    # NOTE(review): looks like a local Ollama OpenAI-compatible endpoint — confirm.
    base_url: str = "http://127.0.0.1:11434/v1"
    # API key sent to the endpoint.
    # NOTE(review): the default appears to be a placeholder for local use — confirm.
    api_key: str = "ollama"
    # Identifier of the model to query.
    model_name: str = "llama2:13b-chat"
    # Token limit passed to the model (default 4096).
    max_tokens: int = 4096
    # Sampling temperature passed to the model (default 0.7).
    temperature: float = 0.7
|
||
|
||
class NLUConfiguration(BaseModel):
    """Configuration for Natural Language Understanding"""

    # Selects which pipeline is active. Declared as a plain ``str``, so the
    # two expected values below are a convention, not enforced by the type.
    pipeline_type: str = "traditional"  # Either 'traditional' or 'llm'
    # Settings for the traditional pipeline; defaults to the model's defaults.
    traditional_settings: TraditionalNLUSettings = TraditionalNLUSettings()
    # Settings for the LLM pipeline; defaults to the model's defaults.
    llm_settings: LLMSettings = LLMSettings()
|
||
|
||
class Bot(BaseModel):
    """Base schema for bot"""

    # Identifier; populated from the ``_id`` key of the input during
    # validation and left as ``None`` when that key is absent.
    id: ObjectIdField = Field(validation_alias="_id", default=None)
    # Name of the bot (required, no default).
    name: str
    # Free-form configuration mapping; empty by default.
    config: Dict[str, Any] = {}

    # Lets pydantic accept field types it has no built-in handling for.
    model_config = ConfigDict(arbitrary_types_allowed=True)
    # NLU configuration; defaults to ``NLUConfiguration``'s own defaults.
    nlu_config: NLUConfiguration = NLUConfiguration()
    # Optional timestamps; both default to ``None``.
    # NOTE(review): presumably set by the persistence layer — confirm.
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from .crf_entity_extractor import CRFEntityExtractor | ||
from .synonym_replacer import SynonymReplacer | ||
|
||
__all__ = ["CRFEntityExtractor"] | ||
__all__ = ["CRFEntityExtractor", "SynonymReplacer"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import logging | ||
from typing import Dict, Any, Optional | ||
from app.bot.nlu.pipeline import NLUComponent | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class SynonymReplacer(NLUComponent):
    """
    Replaces extracted entity values with their root words
    using a predefined synonyms dictionary.

    Matching is done on the lower-cased string form of each entity value, so
    the keys of the synonyms dictionary are expected to be lower-case.
    """

    def __init__(self, synonyms: Optional[Dict[str, str]] = None):
        """
        :param synonyms: Mapping of (lower-case) synonym to root word.
            ``None`` or an empty mapping means no replacement is performed.
        """
        self.synonyms = synonyms or {}

    def replace_synonyms(self, entities: Dict[str, Any]) -> Dict[str, Any]:
        """
        Replace extracted entity values with root words by matching with synonyms dict.

        The dictionary is updated in place and the same object is returned.
        Entities whose value is ``None`` (i.e. not extracted) are left
        untouched rather than being looked up as the string ``"none"``.

        :param entities: Dictionary of entity name to entity value mappings
        :return: Dictionary with replaced entity values where applicable
        """
        # Only values of existing keys are reassigned, so the dict size never
        # changes while iterating.
        for name, value in entities.items():
            if value is None:
                # A missing entity must not accidentally match a "none" synonym.
                continue
            lookup_key = str(value).lower()
            if lookup_key in self.synonyms:
                entities[name] = self.synonyms[lookup_key]
        return entities

    def train(self, training_data: Dict[str, Any], model_path: str) -> None:
        """Nothing to train for synonym replacement."""

    def load(self, model_path: str) -> bool:
        """Nothing to load for synonym replacement."""
        return True

    def process(self, message: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process a message by replacing entity values with their synonyms.

        :param message: Message dictionary; its ``"entities"`` entry, when
            present and non-empty, is rewritten via :meth:`replace_synonyms`.
        :return: The same message dictionary.
        """
        # Missing, None or empty entities: nothing to replace.
        if not message.get("entities"):
            return message

        message["entities"] = self.replace_synonyms(message["entities"])
        return message
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .zero_shot_nlu_openai import ZeroShotNLUOpenAI | ||
|
||
__all__ = ["ZeroShotNLUOpenAI"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
You are provided with a text input. Your task is to analyze the given input and extract specific details based on the following instructions: | ||
|
||
1. **Identify the Intent**: Determine the intent of the input from the following options: | ||
{% for intent in intents %} | ||
- {{ intent }} | ||
{% endfor %} | ||
2. **Extract Entities**: Extract the following entities only if they are explicitly mentioned in the text: | ||
{% for entity in entities %} | ||
- {{ entity }} | ||
{% endfor %} | ||
3. **Strict Extraction Rules**: | ||
- Do not infer or guess any values. If an entity is not mentioned, assign it a value of null. | ||
- Ensure that the output is strictly in JSON format. | ||
- Output only the JSON object. Do not include any additional text, explanations, or comments. | ||
- Ensure that the JSON structure is valid and properly formatted. | ||
4. **Output Format**: Provide the output in the following JSON structure: | ||
{% raw %} | ||
```json | ||
{ | ||
"intent": "<intent_value>" or null, | ||
"entities": { | ||
"entity_name_1": "<value>" or null, | ||
"entity_name_2": "<value>" or null | ||
} | ||
} | ||
``` | ||
{% endraw %} |
Oops, something went wrong.