From 27a2d3ca75fcddf6e9df6083502beb3a551d54c1 Mon Sep 17 00:00:00 2001 From: Stephen Hodgson Date: Thu, 19 Sep 2024 17:24:25 -0400 Subject: [PATCH] com.openai.unity 8.3.0 (#298) - Refactored TypeExtensions and JsonSchema generation - Improved JsonSchema generation for enums and dictionaries - Ensured JsonSchema properly handles nullable types - Ensure that function args are not re-serialized and passed back into tool function for future calls --- Documentation~/README.md | 95 +++++++++- Runtime/Chat/ChatRequest.cs | 11 +- Runtime/Extensions/TypeExtensions.cs | 187 +++++++++++++++---- Runtime/Threads/CreateRunRequest.cs | 11 +- Runtime/Threads/CreateThreadAndRunRequest.cs | 11 +- Tests/TestFixture_00_02_Extensions.cs | 46 +++++ package.json | 2 +- 7 files changed, 310 insertions(+), 53 deletions(-) diff --git a/Documentation~/README.md b/Documentation~/README.md index 2736a9c5..69b87819 100644 --- a/Documentation~/README.md +++ b/Documentation~/README.md @@ -50,7 +50,7 @@ The recommended installation method is though the unity package manager and [Ope > Check out our new api docs! - :new: + ### Table of Contents @@ -88,7 +88,7 @@ The recommended installation method is though the unity package manager and [Ope - [Retrieve Run](#retrieve-thread-run) - [Modify Run](#modify-thread-run) - [Submit Tool Outputs to Run](#thread-submit-tool-outputs-to-run) - - [Structured Outputs](#thread-structured-outputs) :new: + - [Structured Outputs](#thread-structured-outputs) - [List Run Steps](#list-thread-run-steps) - [Retrieve Run Step](#retrieve-thread-run-step) - [Cancel Run](#cancel-thread-run) @@ -113,7 +113,7 @@ The recommended installation method is though the unity package manager and [Ope - [Streaming](#chat-streaming) - [Tools](#chat-tools) - [Vision](#chat-vision) - - [Structured Outputs](#chat-structured-outputs) :new: + - [Structured Outputs](#chat-structured-outputs) - [Json Mode](#chat-json-mode) - [Audio](#audio) - [Create Speech](#create-speech) @@ -840,7 +840,8 @@ public class MathStep To use, simply specify the `MathResponse` type as a generic constraint in either `CreateAssistantAsync`, `CreateRunAsync`, or `CreateThreadAndRunAsync`. ```csharp -var assistant = await OpenAIClient.AssistantsEndpoint.CreateAssistantAsync( +var api = new OpenAIClient(); +var assistant = await api.AssistantsEndpoint.CreateAssistantAsync( new CreateAssistantRequest( name: "Math Tutor", instructions: "You are a helpful math tutor. Guide the user through the solution step by step.", @@ -909,6 +910,81 @@ finally } ``` +You can also manually create json schema json string as well, but you will be responsible for deserializing your response data: + +```csharp +var api = new OpenAIClient(); +var mathSchema = new JsonSchema("math_response", @" +{ + ""type"": ""object"", + ""properties"": { + ""steps"": { + ""type"": ""array"", + ""items"": { + ""type"": ""object"", + ""properties"": { + ""explanation"": { + ""type"": ""string"" + }, + ""output"": { + ""type"": ""string"" + } + }, + ""required"": [ + ""explanation"", + ""output"" + ], + ""additionalProperties"": false + } + }, + ""final_answer"": { + ""type"": ""string"" + } + }, + ""required"": [ + ""steps"", + ""final_answer"" + ], + ""additionalProperties"": false +}"); +var assistant = await api.AssistantsEndpoint.CreateAssistantAsync( + new CreateAssistantRequest( + name: "Math Tutor", + instructions: "You are a helpful math tutor. Guide the user through the solution step by step.", + model: "gpt-4o-2024-08-06", + jsonSchema: mathSchema)); +ThreadResponse thread = null; + +try +{ + var run = await assistant.CreateThreadAndRunAsync("how can I solve 8x + 7 = -23", + async @event => + { + Debug.Log(@event.ToJsonString()); + await Task.CompletedTask; + }); + thread = await run.GetThreadAsync(); + run = await run.WaitForStatusChangeAsync(); + Debug.Log($"Created thread and run: {run.ThreadId} -> {run.Id} -> {run.CreatedAt}"); + var messages = await thread.ListMessagesAsync(); + + foreach (var response in messages.Items) + { + Debug.Log($"{response.Role}: {response.PrintContent()}"); + } +} +finally +{ + await assistant.DeleteAsync(deleteToolResources: thread == null); + + if (thread != null) + { + var isDeleted = await thread.DeleteAsync(deleteToolResources: true); + Assert.IsTrue(isDeleted); + } +} +``` + ###### [List Thread Run Steps](https://platform.openai.com/docs/api-reference/runs/listRunSteps) Returns a list of run steps belonging to a run. @@ -964,7 +1040,7 @@ Returns a list of vector stores. ```csharp var api = new OpenAIClient(); -var vectorStores = await OpenAIClient.VectorStoresEndpoint.ListVectorStoresAsync(); +var vectorStores = await api.VectorStoresEndpoint.ListVectorStoresAsync(); foreach (var vectorStore in vectorStores.Items) { @@ -1313,6 +1389,7 @@ public class MathStep To use, simply specify the `MathResponse` type as a generic constraint when requesting a completion. ```csharp +var api = new OpenAIClient(); var messages = new List { new(Role.System, "You are a helpful math tutor. Guide the user through the solution step by step."), @@ -1320,7 +1397,7 @@ var messages = new List }; var chatRequest = new ChatRequest(messages, model: new("gpt-4o-2024-08-06")); -var (mathResponse, chatResponse) = await OpenAIClient.ChatEndpoint.GetCompletionAsync(chatRequest); +var (mathResponse, chatResponse) = await api.ChatEndpoint.GetCompletionAsync(chatRequest); for (var i = 0; i < mathResponse.Steps.Count; i++) { @@ -1383,7 +1460,7 @@ Generate streamed audio from the input text. ```csharp var api = new OpenAIClient(); var request = new SpeechRequest("Hello world!"); -var (path, clip) = await OpenAIClient.AudioEndpoint.CreateSpeechStreamAsync(request, partialClip => audioSource.PlayOneShot(partialClip)); +var (path, clip) = await api.AudioEndpoint.CreateSpeechStreamAsync(request, partialClip => audioSource.PlayOneShot(partialClip)); Debug.Log(path); ``` @@ -1538,7 +1615,7 @@ Returns information about a specific file. ```csharp var api = new OpenAIClient(); -var file = await api.FilesEndpoint.GetFileInfoAsync(fileId); +var file = await api.FilesEndpoint.GetFileInfoAsync(fileId); Debug.Log($"{file.Id} -> {file.Object}: {file.FileName} | {file.Size} bytes"); ``` @@ -1638,7 +1715,7 @@ List your organization's batches. ```csharp var api = new OpenAIClient(); -var batches = await api.await OpenAIClient.BatchEndpoint.ListBatchesAsync(); +var batches = await api.BatchEndpoint.ListBatchesAsync(); foreach (var batch in listResponse.Items) { diff --git a/Runtime/Chat/ChatRequest.cs b/Runtime/Chat/ChatRequest.cs index bded4eed..58be6102 100644 --- a/Runtime/Chat/ChatRequest.cs +++ b/Runtime/Chat/ChatRequest.cs @@ -38,7 +38,7 @@ public ChatRequest( { var toolList = tools?.ToList(); - if (toolList != null && toolList.Any()) + if (toolList is { Count: > 0 }) { if (string.IsNullOrWhiteSpace(toolChoice)) { @@ -59,6 +59,15 @@ public ChatRequest( ToolChoice = toolChoice; } } + + foreach (var tool in toolList) + { + if (tool?.Function?.Arguments != null) + { + // just in case clear any lingering func args. + tool.Function.Arguments = null; + } + } } Tools = toolList?.ToList(); diff --git a/Runtime/Extensions/TypeExtensions.cs b/Runtime/Extensions/TypeExtensions.cs index bbc64d02..64eb6bb7 100644 --- a/Runtime/Extensions/TypeExtensions.cs +++ b/Runtime/Extensions/TypeExtensions.cs @@ -4,6 +4,7 @@ using Newtonsoft.Json.Linq; using System; using System.Collections.Generic; +using System.Linq; using System.Reflection; using System.Threading; @@ -87,6 +88,7 @@ public static JObject GenerateJsonSchema(this Type type, JObject rootSchema, Jso { serializer ??= OpenAIClient.JsonSerializer; var schema = new JObject(); + type = UnwrapNullableType(type); if (!type.IsPrimitive && type != typeof(Guid) && @@ -98,60 +100,45 @@ public static JObject GenerateJsonSchema(this Type type, JObject rootSchema, Jso return new JObject { ["$ref"] = $"#/definitions/{type.FullName}" }; } - if (type == typeof(string)) + if (type.TryGetSimpleTypeSchema(out var schemaType)) { - schema["type"] = "string"; - } - else if (type == typeof(int) || - type == typeof(long) || - type == typeof(uint) || - type == typeof(byte) || - type == typeof(sbyte) || - type == typeof(ulong) || - type == typeof(short) || - type == typeof(ushort)) - { - schema["type"] = "integer"; - } - else if (type == typeof(float) || - type == typeof(double) || - type == typeof(decimal)) - { - schema["type"] = "number"; - } - else if (type == typeof(bool)) - { - schema["type"] = "boolean"; - } - else if (type == typeof(DateTime) || type == typeof(DateTimeOffset)) - { - schema["type"] = "string"; - schema["format"] = "date-time"; + schema["type"] = schemaType; + + if (type == typeof(DateTime) || + type == typeof(DateTimeOffset)) + { + schema["format"] = "date-time"; + } + else if (type == typeof(Guid)) + { + schema["format"] = "uuid"; + } } - else if (type == typeof(Guid)) + else if (type.IsEnum) { schema["type"] = "string"; - schema["format"] = "uuid"; + schema["enum"] = new JArray(Enum.GetNames(type).Select(JValue.CreateString).ToArray()); } - else if (type.IsEnum) + else if (type.TryGetDictionaryValueType(out var valueType)) { - schema["type"] = "string"; - schema["enum"] = new JArray(); + schema["type"] = "object"; - foreach (var value in Enum.GetValues(type)) + if (rootSchema["definitions"] != null && + ((JObject)rootSchema["definitions"]).ContainsKey(valueType.FullName!)) { - ((JArray)schema["enum"]).Add(JToken.FromObject(value, serializer)); + schema["additionalProperties"] = new JObject { ["$ref"] = $"#/definitions/{valueType.FullName}" }; + } + else + { + schema["additionalProperties"] = GenerateJsonSchema(valueType, rootSchema); } } - else if (type.IsArray || - type.IsGenericType && (type.GetGenericTypeDefinition() == typeof(List<>) || - type.GetGenericTypeDefinition() == typeof(IReadOnlyList<>))) + else if (type.TryGetCollectionElementType(out var elementType)) { schema["type"] = "array"; - var elementType = type.GetElementType() ?? type.GetGenericArguments()[0]; if (rootSchema["definitions"] != null && - ((JObject)rootSchema["definitions"]).ContainsKey(elementType.FullName!)) + ((JObject)rootSchema["definitions"]).ContainsKey(elementType.FullName!)) { schema["items"] = new JObject { ["$ref"] = $"#/definitions/{elementType.FullName}" }; } @@ -286,6 +273,126 @@ public static JObject GenerateJsonSchema(this Type type, JObject rootSchema, Jso return schema; } + private static bool TryGetSimpleTypeSchema(this Type type, out string schemaType) + { + switch (type) + { + case not null when type == typeof(object): + schemaType = "object"; + return true; + case not null when type == typeof(bool): + schemaType = "boolean"; + return true; + case not null when type == typeof(float) || + type == typeof(double) || + type == typeof(decimal): + schemaType = "number"; + return true; + case not null when type == typeof(char) || + type == typeof(string) || + type == typeof(Guid) || + type == typeof(DateTime) || + type == typeof(DateTimeOffset): + schemaType = "string"; + return true; + case not null when type == typeof(int) || + type == typeof(long) || + type == typeof(uint) || + type == typeof(byte) || + type == typeof(sbyte) || + type == typeof(ulong) || + type == typeof(short) || + type == typeof(ushort): + schemaType = "integer"; + return true; + default: + schemaType = null; + return false; + } + } + + private static bool TryGetDictionaryValueType(this Type type, out Type valueType) + { + valueType = null; + + if (!type.IsGenericType) { return false; } + + var genericTypeDefinition = type.GetGenericTypeDefinition(); + + if (genericTypeDefinition == typeof(Dictionary<,>) || + genericTypeDefinition == typeof(IDictionary<,>) || + genericTypeDefinition == typeof(IReadOnlyDictionary<,>)) + { + return InternalTryGetDictionaryValueType(type, out valueType); + } + + // Check implemented interfaces for dictionary types + foreach (var @interface in type.GetInterfaces()) + { + if (!@interface.IsGenericType) { continue; } + + var interfaceTypeDefinition = @interface.GetGenericTypeDefinition(); + + if (interfaceTypeDefinition == typeof(IDictionary<,>) || + interfaceTypeDefinition == typeof(IReadOnlyDictionary<,>)) + { + return InternalTryGetDictionaryValueType(@interface, out valueType); + } + } + + return false; + + bool InternalTryGetDictionaryValueType(Type dictType, out Type dictValueType) + { + dictValueType = null; + var genericArgs = dictType.GetGenericArguments(); + + // The key type is not string, which cannot be represented in JSON object property names + if (genericArgs[0] != typeof(string)) + { + throw new InvalidOperationException($"Cannot generate schema for dictionary type '{dictType.FullName}' with non-string key type."); + } + + dictValueType = genericArgs[1].UnwrapNullableType(); + return true; + } + } + + private static readonly Type[] arrayTypes = + { + typeof(IEnumerable<>), + typeof(ICollection<>), + typeof(IReadOnlyCollection<>), + typeof(List<>), + typeof(IList<>), + typeof(IReadOnlyList<>), + typeof(HashSet<>), + typeof(ISet<>) + }; + + private static bool TryGetCollectionElementType(this Type type, out Type elementType) + { + elementType = null; + + if (type.IsArray) + { + elementType = type.GetElementType(); + return true; + } + + if (!type.IsGenericType) { return false; } + + var genericTypeDefinition = type.GetGenericTypeDefinition(); + + if (!arrayTypes.Contains(genericTypeDefinition)) { return false; } + + elementType = type.GetGenericArguments()[0].UnwrapNullableType(); + return true; + } + + private static Type UnwrapNullableType(this Type type) + => Nullable.GetUnderlyingType(type) ?? type; + private static Type GetMemberType(MemberInfo member) => member switch { diff --git a/Runtime/Threads/CreateRunRequest.cs b/Runtime/Threads/CreateRunRequest.cs index 384f5612..087d2d12 100644 --- a/Runtime/Threads/CreateRunRequest.cs +++ b/Runtime/Threads/CreateRunRequest.cs @@ -150,7 +150,7 @@ public CreateRunRequest( var toolList = tools?.ToList(); - if (toolList != null && toolList.Any()) + if (toolList is { Count: > 0 }) { if (string.IsNullOrWhiteSpace(toolChoice)) { @@ -171,6 +171,15 @@ public CreateRunRequest( ToolChoice = toolChoice; } } + + foreach (var tool in toolList) + { + if (tool?.Function?.Arguments != null) + { + // just in case clear any lingering func args. + tool.Function.Arguments = null; + } + } } Tools = toolList?.ToList(); diff --git a/Runtime/Threads/CreateThreadAndRunRequest.cs b/Runtime/Threads/CreateThreadAndRunRequest.cs index 0f485a0a..a07c5a5c 100644 --- a/Runtime/Threads/CreateThreadAndRunRequest.cs +++ b/Runtime/Threads/CreateThreadAndRunRequest.cs @@ -150,7 +150,7 @@ public CreateThreadAndRunRequest( var toolList = tools?.ToList(); - if (toolList != null && toolList.Any()) + if (toolList is { Count: > 0 }) { if (string.IsNullOrWhiteSpace(toolChoice)) { @@ -171,6 +171,15 @@ public CreateThreadAndRunRequest( ToolChoice = toolChoice; } } + + foreach (var tool in toolList) + { + if (tool?.Function?.Arguments != null) + { + // just in case clear any lingering func args. + tool.Function.Arguments = null; + } + } } Tools = toolList?.ToList(); diff --git a/Tests/TestFixture_00_02_Extensions.cs b/Tests/TestFixture_00_02_Extensions.cs index fad7bc1d..751b1d0e 100644 --- a/Tests/TestFixture_00_02_Extensions.cs +++ b/Tests/TestFixture_00_02_Extensions.cs @@ -112,5 +112,51 @@ public void Test_02_01_GenerateJsonSchema() JsonSchema mathSchema = typeof(MathResponse); Debug.Log(mathSchema.ToString()); } + + [Test] + public void Test_02_02_GenerateJsonSchema_PrimitiveTypes() + { + JsonSchema schema = typeof(TestSchema); + Debug.Log(schema.ToString()); + } + + private class TestSchema + { + // test all primitive types can be serialized + public bool Bool { get; set; } + public byte Byte { get; set; } + public sbyte SByte { get; set; } + public short Short { get; set; } + public ushort UShort { get; set; } + public int Integer { get; set; } + public uint UInteger { get; set; } + public long Long { get; set; } + public ulong ULong { get; set; } + public float Float { get; set; } + public double Double { get; set; } + public decimal Decimal { get; set; } + public char Char { get; set; } + public string String { get; set; } + public DateTime DateTime { get; set; } + public DateTimeOffset DateTimeOffset { get; set; } + public Guid Guid { get; set; } + // test nullables + public int? NullInt { get; set; } + public DateTime? NullDateTime { get; set; } + public TestEnum TestEnum { get; set; } + public TestEnum? NullEnum { get; set; } + public Dictionary Dictionary { get; set; } + public IDictionary IntDictionary { get; set; } + public IReadOnlyDictionary StringDictionary { get; set; } + public Dictionary CustomDictionary { get; set; } + } + + private enum TestEnum + { + Enum1, + Enum2, + Enum3, + Enum4 + } } } diff --git a/package.json b/package.json index 5320d1f5..ccbbe73c 100644 --- a/package.json +++ b/package.json @@ -3,7 +3,7 @@ "displayName": "OpenAI", "description": "A OpenAI package for the Unity Game Engine to use GPT-4, GPT-3.5, GPT-3 and Dall-E though their RESTful API (currently in beta).\n\nIndependently developed, this is not an official library and I am not affiliated with OpenAI.\n\nAn OpenAI API account is required.", "keywords": [], - "version": "8.2.5", + "version": "8.3.0", "unity": "2021.3", "documentationUrl": "https://github.com/RageAgainstThePixel/com.openai.unity#documentation", "changelogUrl": "https://github.com/RageAgainstThePixel/com.openai.unity/releases",