Skip to content

Commit

Permalink
fix: fix context propagation (#190)
Browse files Browse the repository at this point in the history
* fix context propagation

* updated docs
  • Loading branch information
alfredfrancis authored Feb 2, 2025
1 parent cf06697 commit 389163b
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 38 deletions.
12 changes: 9 additions & 3 deletions app/admin/test/routes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, HTTPException
from app.bot.dialogue_manager.models import UserMessage
from app.dependencies import get_dialogue_manager
from app.bot.dialogue_manager.dialogue_manager import DialogueManager
from app.bot.dialogue_manager.dialogue_manager import (
DialogueManager,
DialogueManagerException,
)

router = APIRouter(prefix="/test", tags=["test"])

Expand All @@ -20,5 +23,8 @@ async def chat(
user_message = UserMessage(
thread_id=body["thread_id"], text=body["text"], context=body["context"]
)
new_state = await dialogue_manager.process(user_message)
try:
new_state = await dialogue_manager.process(user_message)
except DialogueManagerException as e:
raise HTTPException(status_code=400, detail=str(e))
return new_state.to_dict()
12 changes: 9 additions & 3 deletions app/bot/channels/rest/routes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, HTTPException
from app.bot.dialogue_manager.models import UserMessage
from app.dependencies import get_dialogue_manager
from app.bot.dialogue_manager.dialogue_manager import DialogueManager
from app.bot.dialogue_manager.dialogue_manager import (
DialogueManager,
DialogueManagerException,
)

router = APIRouter(prefix="/rest", tags=["rest"])

Expand All @@ -20,5 +23,8 @@ async def webbook(
user_message = UserMessage(
thread_id=body["thread_id"], text=body["text"], context=body["context"]
)
new_state = await dialogue_manager.process(user_message)
try:
new_state = await dialogue_manager.process(user_message)
except DialogueManagerException as e:
raise HTTPException(status_code=400, message=str(e))
return new_state.bot_message
73 changes: 48 additions & 25 deletions app/bot/dialogue_manager/dialogue_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,17 @@
ParameterModel,
UserMessage,
)
from app.bot.dialogue_manager.http_client import call_api
from app.bot.dialogue_manager.http_client import call_api, APICallExcetion
from app.config import app_config
from app.database import client

logger = logging.getLogger("dialogue_manager")


class DialogueManagerException(Exception):
pass


class DialogueManager:
def __init__(
self,
Expand Down Expand Up @@ -77,7 +81,9 @@ def update_model(self, models_dir):
Reloads ML models and synonyms.
"""
# Load models
self.nlu_pipeline.load(models_dir)
ok = self.nlu_pipeline.load(models_dir)
if not ok:
self.nlu_pipeline = None
logger.info("NLU Pipeline models updated")

async def process(self, message: UserMessage) -> State:
Expand All @@ -88,6 +94,11 @@ async def process(self, message: UserMessage) -> State:
:return: current state of the conversation including the bot response
"""

if self.nlu_pipeline is None:
raise DialogueManagerException(
"NLU pipeline is not initialized. Please build the models."
)

# Step 1: Get current state
current_state = await self.memory_saver.get(message.thread_id)

Expand Down Expand Up @@ -249,9 +260,6 @@ def _process_intent(
entities_by_type[param.type].pop(0)
)

# Update context with extracted_parameters
current_state.context["parameters"] = current_state.extracted_parameters

# Handle missing parameters
current_state = self._handle_missing_parameters(parameters, current_state)

Expand Down Expand Up @@ -301,57 +309,72 @@ async def _handle_api_trigger(
"""
if intent.api_trigger and intent.api_details:
try:
result = await self._call_intent_api(intent, current_state.context)
current_state.context["result"] = result
result = await self._call_intent_api(intent, current_state)
template = Template(
intent.speech_response,
undefined=SilentUndefined,
enable_async=True,
)
rendered_text = await template.render_async(**current_state.context)
rendered_text = await template.render_async(
context=current_state.context,
parameters=current_state.extracted_parameters,
result=result,
)

current_state.bot_message = [
{"text": msg} for msg in split_sentence(rendered_text)
]
except Exception as e:

except DialogueManagerException as e:
logger.warning(f"API call failed: {e}")
current_state.bot_message = [
{"text": "Service is not available. Please try again later."}
]
else:
current_state.context["result"] = {}
template = Template(
intent.speech_response,
undefined=SilentUndefined,
enable_async=True,
)
rendered_text = await template.render_async(**current_state.context)
rendered_text = await template.render_async(
context=current_state.context,
parameters=current_state.extracted_parameters,
)
current_state.bot_message = [
{"text": msg} for msg in split_sentence(rendered_text)
]

return current_state

async def _call_intent_api(self, intent: IntentModel, context: Dict):
async def _call_intent_api(self, intent: IntentModel, current_state: State):
"""
Call the API associated with the intent.
"""
api_details = intent.api_details
headers = api_details.get_headers()
url_template = Template(api_details.url, undefined=SilentUndefined)
rendered_url = url_template.render(**context)

rendered_url = url_template.render(
context=current_state.context, parameters=current_state.extracted_parameters
)
if api_details.is_json:
request_template = Template(
api_details.json_data, undefined=SilentUndefined
)
parameters = json.loads(request_template.render(**context))
request_json = request_template.render(
context=current_state.context,
parameters=current_state.extracted_parameters,
)
parameters = json.loads(request_json)
else:
parameters = context.get("parameters", {})

return await call_api(
rendered_url,
api_details.request_type,
headers,
parameters,
api_details.is_json,
)
parameters = current_state.extracted_parameters

try:
return await call_api(
rendered_url,
api_details.request_type,
headers,
parameters,
api_details.is_json,
)
except APICallExcetion as e:
logger.warning(f"API call failed: {e}")
raise DialogueManagerException("API call failed")
8 changes: 6 additions & 2 deletions app/bot/dialogue_manager/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
logger = logging.getLogger("http_client")


class APICallExcetion(Exception):
pass


async def call_api(
url: str,
method: str,
Expand Down Expand Up @@ -73,10 +77,10 @@ async def call_api(

except aiohttp.ClientError as e:
logger.error(f"HTTP error occurred: {str(e)}")
raise
raise APICallExcetion(f"HTTP error occurred: {str(e)}")
except asyncio.TimeoutError:
logger.error(f"Request timed out after {timeout} seconds")
raise
raise APICallExcetion(f"Request timed out after {timeout} seconds")
except Exception as e:
logger.error(f"Unexpected error during API call: {str(e)}")
raise
1 change: 1 addition & 0 deletions app/bot/memory/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def from_dict(cls, state_dict: Dict) -> "State":
def update(self, user_message: UserMessage):
self.user_message = user_message
self.date = datetime.now(UTC)
self.context.update(user_message.context)

if self.complete:
self.bot_message = []
Expand Down
5 changes: 2 additions & 3 deletions app/bot/nlu/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,10 @@ def train(self, training_data: List[Dict[str, Any]], model_path: str) -> None:

def load(self, model_path: str) -> bool:
"""Load all components from model path."""
success = True
for component in self.components:
if not component.load(model_path):
success = False
return success
return False
return True

def process(self, message: Dict[str, Any]) -> Dict[str, Any]:
"""Process message through all components in sequence."""
Expand Down
Binary file modified docs/screenshots/admin_chat_screenshot.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion examples/order_status.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"intents": [{"name": "Default Fallback intent", "userDefined": false, "intentId": "fallback", "apiTrigger": false, "apiDetails": null, "speechResponse": "{{ \n\n[\n \"Sorry ### I'm having trouble understanding you.\",\n \"Hmm ### I cant handle that yet.\",\n \"Can you please re-phrase your query ?\"\n] | random \n\n}}\ufeff\n\n", "parameters": [], "labeledSentences": [], "trainingData": []}, {"name": "cancel", "userDefined": false, "intentId": "cancel", "apiTrigger": false, "apiDetails": null, "speechResponse": "Ok. Canceled.", "parameters": [], "labeledSentences": [], "trainingData": [{"text": "i want to cancel", "entities": []}, {"text": "cancel that", "entities": []}, {"text": "cancel", "entities": []}]}, {"name": "Welcome message", "userDefined": false, "intentId": "init_conversation", "apiTrigger": false, "apiDetails": null, "speechResponse": "Hi {{context[\"username\"] }} ### What can i do for you ?", "parameters": [], "labeledSentences": [], "trainingData": [{"text": "hello there", "entities": []}, {"text": "hey there", "entities": []}, {"text": "hii", "entities": []}, {"text": "heyy", "entities": []}, {"text": "howdy", "entities": []}, {"text": "hey", "entities": []}, {"text": "hello", "entities": []}, {"text": "hi", "entities": []}]}, {"name": "Check Order Status", "userDefined": true, "intentId": "check_order_status", "apiTrigger": true, "apiDetails": {"url": "https://fake-store-api.mock.beeceptor.com/api/orders/status?order_id={{ parameters['order_number'] }}", "requestType": "GET", "headers": [], "isJson": false, "jsonData": ""}, "speechResponse": "Your order status is {{ result['status'] }}", "parameters": [{"name": "order_number", "required": true, "type": "order_number", "prompt": "Please provide your order number"}], "labeledSentences": [], "trainingData": [{"text": "my order number is ORD113134", "entities": [{"value": "ORD113134", "begin": 19, "end": 28, "name": "order_number"}]}, {"text": "What is my order status", "entities": []}, {"text": "I want to know about order ORD123456", "entities": [{"value": "ORD123456", "begin": 27, "end": 36, "name": "order_number"}]}, {"text": "Can you check order ORD789012 for me ?", "entities": [{"value": "ORD789012", "begin": 20, "end": 29, "name": "order_number"}]}, {"text": "Where is my order ORD123456 ?", "entities": [{"value": "ORD123456", "begin": 18, "end": 27, "name": "order_number"}]}, {"text": "Track order ORDER789012", "entities": [{"value": "ORDER789012", "begin": 12, "end": 23, "name": "order_number"}]}, {"text": "What's the status of my order ORD123456 ?", "entities": [{"value": "ORD123456", "begin": 30, "end": 39, "name": "order_number"}]}, {"text": "Tell me order status", "entities": []}]}], "entities": [{"name": "order_number", "entity_values": []}]}
{"intents": [{"name": "Default Fallback intent", "userDefined": false, "intentId": "fallback", "apiTrigger": false, "apiDetails": null, "speechResponse": "{{ \n\n[\n \"Sorry ### I'm having trouble understanding you.\",\n \"Hmm ### I cant handle that yet.\",\n \"Can you please re-phrase your query ?\"\n] | random \n\n}}\ufeff\n\n", "parameters": [], "labeledSentences": [], "trainingData": []}, {"name": "cancel", "userDefined": false, "intentId": "cancel", "apiTrigger": false, "apiDetails": null, "speechResponse": "Ok. Canceled.", "parameters": [], "labeledSentences": [], "trainingData": [{"text": "i want to cancel", "entities": []}, {"text": "cancel that", "entities": []}, {"text": "cancel", "entities": []}]}, {"name": "Welcome message", "userDefined": false, "intentId": "init_conversation", "apiTrigger": false, "apiDetails": null, "speechResponse": "Hi {{context[\"username\"] }} ### What can i do for you ?", "parameters": [], "labeledSentences": [], "trainingData": [{"text": "hello there", "entities": []}, {"text": "hey there", "entities": []}, {"text": "hii", "entities": []}, {"text": "heyy", "entities": []}, {"text": "howdy", "entities": []}, {"text": "hey", "entities": []}, {"text": "hello", "entities": []}, {"text": "hi", "entities": []}]}, {"name": "Check Order Status", "userDefined": true, "intentId": "check_order_status", "apiTrigger": true, "apiDetails": {"url": "https://fake-store-api.mock.beeceptor.com/api/orders/status?order_id={{ parameters['order_number'] }}", "requestType": "GET", "headers": [], "isJson": false, "jsonData": ""}, "speechResponse": " Let me check the status of order {{ parameters.order_number }} ###\n Your order status is <b>{{ result.status }}</b> and is expected to arrive in 2-3 business days.", "parameters": [{"name": "order_number", "required": true, "type": "order_number", "prompt": "Sure ### Can you please give me your order number ?"}], "labeledSentences": [], "trainingData": [{"text": "my order number is ORD113134", "entities": [{"value": "ORD113134", "begin": 19, "end": 28, "name": "order_number"}]}, {"text": "What is my order status", "entities": []}, {"text": "I want to know about order ORD123456", "entities": [{"value": "ORD123456", "begin": 27, "end": 36, "name": "order_number"}]}, {"text": "Can you check order ORD789012 for me ?", "entities": [{"value": "ORD789012", "begin": 20, "end": 29, "name": "order_number"}]}, {"text": "Where is my order ORD123456 ?", "entities": [{"value": "ORD123456", "begin": 18, "end": 27, "name": "order_number"}]}, {"text": "Track order ORDER789012", "entities": [{"value": "ORDER789012", "begin": 12, "end": 23, "name": "order_number"}]}, {"text": "What's the status of my order ORD123456 ?", "entities": [{"value": "ORD123456", "begin": 30, "end": 39, "name": "order_number"}]}, {"text": "Tell me order status", "entities": []}]}], "entities": [{"name": "order_number", "entity_values": []}]}
2 changes: 1 addition & 1 deletion frontend/.env.development
Original file line number Diff line number Diff line change
@@ -1 +1 @@
NEXT_PUBLIC_API_BASE_URL=http://localhost:8080/
NEXT_PUBLIC_API_BASE_URL=http://127.0.0.1:8080

0 comments on commit 389163b

Please sign in to comment.