Skip to content

Commit

Permalink
AIX: dispatchers reuse the same doc wrapping
Browse files Browse the repository at this point in the history
  • Loading branch information
enricoros committed Jan 10, 2025
1 parent 08437f1 commit 213de18
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 25 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { escapeXml } from '~/server/wire';

import type { AixAPI_Model, AixAPIChatGenerate_Request, AixMessages_ChatMessage, AixParts_MetaInReferenceToPart, AixTools_ToolDefinition, AixTools_ToolsPolicy } from '../../../api/aix.wiretypes';
import type { AixAPI_Model, AixAPIChatGenerate_Request, AixMessages_ChatMessage, AixParts_DocPart, AixParts_MetaInReferenceToPart, AixTools_ToolDefinition, AixTools_ToolsPolicy } from '../../../api/aix.wiretypes';
import { AnthropicWire_API_Message_Create, AnthropicWire_Blocks } from '../../wiretypes/anthropic.wiretypes';


Expand All @@ -21,6 +21,15 @@ export function aixToAnthropicMessageCreate(model: AixAPI_Model, chatGenerate: A
if (chatGenerate.systemMessage?.parts.length) {
systemMessage = chatGenerate.systemMessage.parts.reduce((acc, part) => {
switch (part.pt) {

case 'text':
acc.push(AnthropicWire_Blocks.TextBlock(part.text));
break;

case 'doc':
acc.push(AnthropicWire_Blocks.TextBlock(approxDocPart_To_String(part)));
break;

case 'meta_cache_control':
if (!acc.length)
console.warn('Anthropic: cache_control without a message to attach to');
Expand All @@ -29,12 +38,16 @@ export function aixToAnthropicMessageCreate(model: AixAPI_Model, chatGenerate: A
else
AnthropicWire_Blocks.blockSetCacheControl(acc[acc.length - 1], 'ephemeral');
break;
case 'text':
acc.push(AnthropicWire_Blocks.TextBlock(part.text));
break;

default:
throw new Error(`Unsupported part type in System message: ${(part as any).pt}`);
}
return acc;
}, [] as Exclude<TRequest['system'], undefined>);

// unset system message if empty
if (!systemMessage.length)
systemMessage = undefined;
}

// Transform the chat messages into Anthropic's format
Expand Down Expand Up @@ -91,7 +104,7 @@ export function aixToAnthropicMessageCreate(model: AixAPI_Model, chatGenerate: A
// Top-P instead of temperature
if (model.topP !== undefined) {
payload.top_p = model.topP;
delete payload.temperature
delete payload.temperature;
}

// Preemptive error detection with server-side payload validation before sending it upstream
Expand Down Expand Up @@ -136,11 +149,11 @@ function* _generateAnthropicMessagesContentBlocks({ parts, role }: AixMessages_C
break;

case 'doc':
yield { role: 'user', content: AnthropicWire_Blocks.TextBlock('```' + (part.ref || '') + '\n' + part.data.text + '\n```\n') };
yield { role: 'user', content: AnthropicWire_Blocks.TextBlock(approxDocPart_To_String(part)) };
break;

case 'meta_in_reference_to':
const irtXMLString = inReferenceTo_To_XMLString(part);
const irtXMLString = approxInReferenceTo_To_XMLString(part);
if (irtXMLString)
yield { role: 'user', content: AnthropicWire_Blocks.TextBlock(irtXMLString) };
break;
Expand Down Expand Up @@ -259,7 +272,22 @@ function _toAnthropicToolChoice(itp: AixTools_ToolsPolicy): NonNullable<TRequest
}
}

export function inReferenceTo_To_XMLString(irt: AixParts_MetaInReferenceToPart): string | null {

// Approximate conversions - alternative approaches should be tried until we find the best one

export function approxDocPart_To_String({ ref, data }: AixParts_DocPart /*, wrapFormat?: 'markdown-code'*/): string {
// NOTE: Consider a better representation here
//
// We use the 'legacy' markdown encoding, but we may consider:
// - '<doc id='ref' title='title' version='version'>\n...\n</doc>'
// - ```doc id='ref' title='title' version='version'\n...\n```
// - # Title [id='ref' version='version']\n...\n
// - ...more ideas...
//
return '```' + (ref || '') + '\n' + data.text + '\n```\n';
}

export function approxInReferenceTo_To_XMLString(irt: AixParts_MetaInReferenceToPart): string | null {
const refs = irt.referTo.map(r => escapeXml(r.mText));
if (!refs.length)
return null; // `<context>User provides no specific references</context>`;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import type { AixAPI_Model, AixAPIChatGenerate_Request, AixMessages_ChatMessage, AixParts_DocPart, AixTools_ToolDefinition, AixTools_ToolsPolicy } from '../../../api/aix.wiretypes';
import { GeminiWire_API_Generate_Content, GeminiWire_ContentParts, GeminiWire_Messages, GeminiWire_Safety, GeminiWire_ToolDeclarations } from '../../wiretypes/gemini.wiretypes';

import { inReferenceTo_To_XMLString } from './anthropic.messageCreate';
import { approxDocPart_To_String, approxInReferenceTo_To_XMLString } from './anthropic.messageCreate';


// configuration
Expand All @@ -18,17 +18,26 @@ export function aixToGeminiGenerateContent(model: AixAPI_Model, chatGenerate: Ai
if (chatGenerate.systemMessage?.parts.length) {
systemInstruction = chatGenerate.systemMessage.parts.reduce((acc, part) => {
switch (part.pt) {
case 'meta_cache_control':
// ignore - we implement caching in the Anthropic way for now
break;

case 'text':
acc.parts.push(GeminiWire_ContentParts.TextPart(part.text));
break;

case 'doc':
acc.parts.push(GeminiWire_ContentParts.TextPart(approxDocPart_To_String(part)));
break;

case 'meta_cache_control':
// ignore - we implement caching in the Anthropic way for now
break;

default:
throw new Error(`Unsupported part type in System message: ${(part as any).pt}`);
}
return acc;
}, { parts: [] } as Exclude<TRequest['systemInstruction'], undefined>);

// unset system instructions with no parts
// unset system instruction if empty
if (!systemInstruction.parts.length)
systemInstruction = undefined;
}
Expand Down Expand Up @@ -120,7 +129,7 @@ function _toGeminiContents(chatSequence: AixMessages_ChatMessage[]): GeminiWire_
break;

case 'meta_in_reference_to':
const irtXMLString = inReferenceTo_To_XMLString(part);
const irtXMLString = approxInReferenceTo_To_XMLString(part);
if (irtXMLString)
parts.push(GeminiWire_ContentParts.TextPart(irtXMLString));
break;
Expand Down Expand Up @@ -291,5 +300,6 @@ function _toGeminiSafetySettings(threshold: GeminiWire_Safety.HarmBlockThreshold
// Approximate conversions - alternative approaches should be tried until we find the best one

function _toApproximateGeminiDocPart(aixPartsDocPart: AixParts_DocPart): GeminiWire_ContentParts.ContentPart {
return GeminiWire_ContentParts.TextPart(`\`\`\`${aixPartsDocPart.ref || ''}\n${aixPartsDocPart.data.text}\n\`\`\`\n`);
// NOTE: we keep this function because we could use Gemini's different way to represent documents in the future...
return GeminiWire_ContentParts.TextPart(approxDocPart_To_String(aixPartsDocPart));
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ import type { OpenAIDialects } from '~/modules/llms/server/openai/openai.router'
import type { AixAPI_Model, AixAPIChatGenerate_Request, AixMessages_ChatMessage, AixMessages_SystemMessage, AixParts_DocPart, AixParts_MetaInReferenceToPart, AixTools_ToolDefinition, AixTools_ToolsPolicy } from '../../../api/aix.wiretypes';
import { OpenAIWire_API_Chat_Completions, OpenAIWire_ContentParts, OpenAIWire_Messages } from '../../wiretypes/openai.wiretypes';

import { approxDocPart_To_String } from './anthropic.messageCreate';


//
// OpenAI API - Chat Adapter - Implementation Notes
Expand Down Expand Up @@ -458,17 +460,11 @@ function _toApproximateOpanAIFlattenSystemMessage(texts: OpenAIWire_ContentParts
return texts.map(text => text.text).join(approxSystemMessageJoiner);
}

function _toApproximateOpenAIDocPart({ data, ref }: AixParts_DocPart): OpenAIWire_ContentParts.TextContentPart {
function _toApproximateOpenAIDocPart(part: AixParts_DocPart): OpenAIWire_ContentParts.TextContentPart {

// Corner case, low probability: if the content is already enclosed in triple-backticks, return it as-is
if (data.text.startsWith('```'))
return OpenAIWire_ContentParts.TextContentPart(data.text);

// TODO: consider a better representation here - we use the 'legacy' markdown encoding
// but we may as well support different ones in the future, such as:
// - '<doc id='ref' title='title' version='version'>\n...\n</doc>'
// - ```doc id='ref' title='title' version='version'\n...\n```
// - etc.
if (part.data.text.startsWith('```'))
return OpenAIWire_ContentParts.TextContentPart(part.data.text);

return OpenAIWire_ContentParts.TextContentPart(`\`\`\`${ref || ''}\n${data.text}\n\`\`\`\n`);
return OpenAIWire_ContentParts.TextContentPart(approxDocPart_To_String(part));
}

0 comments on commit 213de18

Please sign in to comment.