Skip to content

Commit

Permalink
feat(assistant): add streaming (#400)
Browse files Browse the repository at this point in the history
* Added feature assistant-streaming

* Fixed THREAD_MESSAGE_CREATED data type

---------

Co-authored-by: Nicholas Dalton <[email protected]>
  • Loading branch information
Daltomon and Nicholas Dalton authored Feb 1, 2025
1 parent 4517677 commit 04ab4a6
Show file tree
Hide file tree
Showing 12 changed files with 593 additions and 2 deletions.
54 changes: 54 additions & 0 deletions guides/GettingStarted.md
Original file line number Diff line number Diff line change
Expand Up @@ -849,3 +849,57 @@ val runSteps = openAI.runSteps(
runId = RunId("run_abc123")
)
```

### Event streaming

Create a thread and run it in one request and process streaming events.

```kotlin
openAI.createStreamingThreadRun(
request = ThreadRunRequest(
assistantId = AssistantId("asst_abc123"),
thread = ThreadRequest(
messages = listOf(
ThreadMessage(
role = Role.User,
content = "Explain deep learning to a 5 year old."
)
)
),
)
)
.onEach { assistantStreamEvent: AssistantStreamEvent -> println(assistantStreamEvent) }
.collect()
```

Get data object from AssistantStreamEvent.

```kotlin
//Type of data for generic type can be found in AssistantStreamEventType
when(assistantStreamEvent.type) {
AssistantStreamEventType.THREAD_CREATED -> {
val thread = assistantStreamEvent.getData<Thread>()
...
}
AssistantStreamEventType.MESSAGE_CREATED -> {
val message = assistantStreamEvent.getData<Message>()
...
}
AssistantStreamEventType.UNKNOWN -> {
//Data field is a string and can be used instead of calling getData
val data = assistantStreamEvent.data
//Handle unknown message type
}
}
```

If a new event type is released before the library is updated, you can create and deserialize your own type by providing a KSerializer.

```kotlin
when(assistantStreamEvent.type) {
AssistantStreamEventType.UNKNOWN -> {
val data = assistantStreamEvent.getData<MyCustomType>(myCustomSerializer)
...
}
}
```
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import com.aallam.openai.api.core.SortOrder
import com.aallam.openai.api.core.Status
import com.aallam.openai.api.run.*
import com.aallam.openai.api.thread.ThreadId
import io.ktor.sse.ServerSentEvent
import kotlinx.coroutines.flow.Flow

/**
* Represents an execution run on a thread.
Expand All @@ -23,6 +25,21 @@ public interface Runs {
@BetaOpenAI
public suspend fun createRun(threadId: ThreadId, request: RunRequest, requestOptions: RequestOptions? = null): Run

/**
* Create a run with event streaming.
*
* @param threadId The ID of the thread to run
* @param request request for a run
* @param requestOptions request options.
* @return a [Flow] of [AssistantStreamEvent] values for the streamed run.
*/
@BetaOpenAI
public suspend fun createStreamingRun(
threadId: ThreadId,
request: RunRequest,
requestOptions: RequestOptions? = null
) : Flow<AssistantStreamEvent>

/**
* Retrieves a run.
*
Expand Down Expand Up @@ -92,6 +109,25 @@ public interface Runs {
requestOptions: RequestOptions? = null
): Run

/**
* When a run has the status: [Status.RequiresAction] and required action is [RequiredAction.SubmitToolOutputs],
* this endpoint can be used to submit the outputs from the tool calls once they're all completed.
* All outputs must be submitted in a single request using event streaming.
*
* @param threadId the ID of the thread to which this run belongs
* @param runId the ID of the run to submit tool outputs for
* @param output list of tool outputs to submit
* @param requestOptions request options.
* @return a [Flow] of [AssistantStreamEvent] values streamed in response to the submission.
*/
@BetaOpenAI
public suspend fun submitStreamingToolOutput(
threadId: ThreadId,
runId: RunId,
output: List<ToolOutput>,
requestOptions: RequestOptions? = null
) : Flow<AssistantStreamEvent>

/**
* Cancels a run that is [Status.InProgress].
*
Expand All @@ -111,6 +147,19 @@ public interface Runs {
@BetaOpenAI
public suspend fun createThreadRun(request: ThreadRunRequest, requestOptions: RequestOptions? = null): Run

/**
* Create a thread and run it in one request with event streaming.
*
* @param request request for a thread run
* @param requestOptions request options.
* @return a [Flow] of [AssistantStreamEvent] values for the streamed thread run.
*/
@BetaOpenAI
public suspend fun createStreamingThreadRun(
request: ThreadRunRequest,
requestOptions: RequestOptions? = null
) : Flow<AssistantStreamEvent>

/**
* Retrieves a run step.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package com.aallam.openai.client.extension

import com.aallam.openai.api.run.AssistantStreamEvent
import com.aallam.openai.client.internal.JsonLenient
import kotlinx.serialization.KSerializer

/**
 * Get the data of the [AssistantStreamEvent], decoded with the serializer exposed by its event type.
 *
 * The serializer is cast to `KSerializer<T>` without a runtime check of `T`, so the caller is
 * responsible for requesting the type that matches the event type.
 *
 * @throws IllegalStateException if the event type provides no usable serializer, or if the
 * [AssistantStreamEvent] data is null.
 * @throws ClassCastException if the decoded data cannot be cast to the requested type.
 */
@Suppress("UNCHECKED_CAST")
public fun <T> AssistantStreamEvent.getData(): T {
    // Guard clause: without a serializer for this event type there is nothing to decode with.
    val typedSerializer = (type.serializer as? KSerializer<T>)
        ?: throw IllegalStateException("Failed to decode ServerSentEvent: $rawType")
    return getData(typedSerializer)
}


/**
 * Get the data of the [AssistantStreamEvent] using the provided [serializer].
 *
 * Errors raised by [JsonLenient] while decoding are propagated unchanged.
 *
 * @param serializer the serializer used to decode the event data.
 * @return the decoded value.
 * @throws IllegalStateException if the [AssistantStreamEvent] data is null.
 */
public fun <T> AssistantStreamEvent.getData(serializer: KSerializer<T>): T {
    // Only a missing payload is an error here. The decoded value is returned as-is, so a
    // nullable serializer that legitimately yields null is not misreported as "data was null".
    val payload = data ?: throw IllegalStateException("ServerSentEvent data was null: $rawType")
    return JsonLenient.decodeFromString(serializer, payload)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package com.aallam.openai.client.extension

import com.aallam.openai.api.run.AssistantStreamEvent
import com.aallam.openai.api.run.AssistantStreamEventType
import com.aallam.openai.client.internal.JsonLenient
import io.ktor.sse.ServerSentEvent
import kotlinx.serialization.KSerializer

/**
 * Convert a [ServerSentEvent] to [AssistantStreamEvent].
 *
 * The raw event name and the data are passed through untouched. The type falls back to
 * [AssistantStreamEventType.UNKNOWN] if the event is null or unrecognized.
 */
internal fun ServerSentEvent.toAssistantStreamEvent(): AssistantStreamEvent {
    val rawEvent = event
    // Only look the name up when the server actually sent one.
    val resolvedType = if (rawEvent != null) AssistantStreamEventType.fromEvent(rawEvent) else null
    return AssistantStreamEvent(
        rawEvent,
        resolvedType ?: AssistantStreamEventType.UNKNOWN,
        data
    )
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import io.ktor.client.plugins.auth.*
import io.ktor.client.plugins.auth.providers.*
import io.ktor.client.plugins.contentnegotiation.*
import io.ktor.client.plugins.logging.*
import io.ktor.client.plugins.sse.SSE
import io.ktor.http.*
import io.ktor.serialization.kotlinx.*
import io.ktor.util.*
Expand Down Expand Up @@ -71,6 +72,8 @@ internal fun createHttpClient(config: OpenAIConfig): HttpClient {
exponentialDelay(config.retry.base, config.retry.maxDelay.inWholeMilliseconds)
}

install(SSE)

defaultRequest {
url(config.host.baseUrl)
config.host.queryParams.onEach { (key, value) -> url.parameters.appendIfNameAbsent(key, value) }
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.aallam.openai.client.internal.api

import com.aallam.openai.api.BetaOpenAI
import com.aallam.openai.api.core.PaginatedList
import com.aallam.openai.api.core.RequestOptions
import com.aallam.openai.api.core.SortOrder
Expand All @@ -13,20 +14,37 @@ import com.aallam.openai.client.internal.http.perform
import io.ktor.client.call.*
import io.ktor.client.request.*
import io.ktor.http.*
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.flow.onEach

internal class RunsApi(val requester: HttpRequester) : Runs {
// Creates a run on the given thread: POSTs the request as JSON to the thread's runs endpoint
// and returns the response body as a Run.
override suspend fun createRun(threadId: ThreadId, request: RunRequest, requestOptions: RequestOptions?): Run {
return requester.perform {
it.post {
url(path = "${ApiPath.Threads}/${threadId.id}/runs")
// NOTE(review): two consecutive setBody calls are visible here; presumably only the
// second one (stream = false) is meant to remain — confirm.
setBody(request)
setBody(request.copy(stream = false))
contentType(ContentType.Application.Json)
beta("assistants", 2)
requestOptions(requestOptions)
}.body()
}
}

/**
 * Starts a run on the given thread with streaming enabled and returns the resulting event flow.
 *
 * The body is a copy of [request] with `stream = true`, sent as JSON via POST while accepting
 * `text/event-stream`.
 */
@BetaOpenAI
override suspend fun createStreamingRun(
    threadId: ThreadId,
    request: RunRequest,
    requestOptions: RequestOptions?
): Flow<AssistantStreamEvent> {
    val endpoint = "${ApiPath.Threads}/${threadId.id}/runs"
    val streamingRequest = request.copy(stream = true)
    return requester.performSse {
        url(path = endpoint)
        setBody(streamingRequest)
        contentType(ContentType.Application.Json)
        accept(ContentType.Text.EventStream)
        beta("assistants", 2)
        requestOptions(requestOptions)
        method = HttpMethod.Post
    }
}

override suspend fun getRun(threadId: ThreadId, runId: RunId, requestOptions: RequestOptions?): Run {
return requester.perform {
it.get {
Expand Down Expand Up @@ -95,6 +113,25 @@ internal class RunsApi(val requester: HttpRequester) : Runs {
}
}

/**
 * Submits tool outputs for a run with streaming enabled and returns the resulting event flow.
 *
 * The body carries the outputs under `tool_outputs` together with `stream = true`, sent as JSON
 * via POST while accepting `text/event-stream`.
 */
@BetaOpenAI
override suspend fun submitStreamingToolOutput(
    threadId: ThreadId,
    runId: RunId,
    output: List<ToolOutput>,
    requestOptions: RequestOptions?
): Flow<AssistantStreamEvent> {
    val endpoint = "${ApiPath.Threads}/${threadId.id}/runs/${runId.id}/submit_tool_outputs"
    return requester.performSse {
        url(path = endpoint)
        // NOTE(review): this map mixes a list value and a Boolean value, so its value type is
        // not uniform — confirm the configured JSON serializer can encode it.
        setBody(mapOf("tool_outputs" to output, "stream" to true))
        contentType(ContentType.Application.Json)
        accept(ContentType.Text.EventStream)
        beta("assistants", 2)
        requestOptions(requestOptions)
        method = HttpMethod.Post
    }
}

override suspend fun cancel(threadId: ThreadId, runId: RunId, requestOptions: RequestOptions?): Run {
return requester.perform {
it.post {
Expand All @@ -109,14 +146,32 @@ internal class RunsApi(val requester: HttpRequester) : Runs {
return requester.perform {
it.post {
url(path = "${ApiPath.Threads}/runs")
setBody(request)
setBody(request.copy(stream = false))
contentType(ContentType.Application.Json)
beta("assistants", 2)
requestOptions(requestOptions)
}.body()
}
}

/**
 * Creates a thread and runs it in one request with streaming enabled, returning the event flow.
 *
 * The body is a copy of [request] with `stream = true`, sent as JSON via POST while accepting
 * `text/event-stream`.
 */
@BetaOpenAI
override suspend fun createStreamingThreadRun(
    request: ThreadRunRequest,
    requestOptions: RequestOptions?
): Flow<AssistantStreamEvent> {
    val endpoint = "${ApiPath.Threads}/runs"
    val streamingRequest = request.copy(stream = true)
    return requester.performSse {
        url(path = endpoint)
        setBody(streamingRequest)
        contentType(ContentType.Application.Json)
        accept(ContentType.Text.EventStream)
        beta("assistants", 2)
        requestOptions(requestOptions)
        method = HttpMethod.Post
    }
}



override suspend fun runStep(
threadId: ThreadId,
runId: RunId,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
package com.aallam.openai.client.internal.http

import com.aallam.openai.api.run.AssistantStreamEvent
import io.ktor.client.*
import io.ktor.client.plugins.sse.ClientSSESession
import io.ktor.client.request.*
import io.ktor.client.statement.*
import io.ktor.sse.ServerSentEvent
import io.ktor.util.reflect.*
import kotlinx.coroutines.flow.Flow

/**
* Http request performer.
Expand All @@ -15,6 +19,14 @@ internal interface HttpRequester : AutoCloseable {
*/
suspend fun <T : Any> perform(info: TypeInfo, block: suspend (HttpClient) -> HttpResponse): T

/**
* Perform an HTTP request and expose the server-sent events it emits as a flow.
*
* @param builderBlock configures the [HttpRequestBuilder] used for the request.
* @return a [Flow] of [AssistantStreamEvent] values.
*/
suspend fun performSse(
builderBlock: HttpRequestBuilder.() -> Unit
): Flow<AssistantStreamEvent>

/**
* Perform an HTTP request and get a result.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,26 @@
package com.aallam.openai.client.internal.http

import com.aallam.openai.api.exception.*
import com.aallam.openai.api.run.AssistantStreamEvent
import com.aallam.openai.client.extension.toAssistantStreamEvent
import com.aallam.openai.client.internal.api.ApiPath
import io.ktor.client.*
import io.ktor.client.call.*
import io.ktor.client.network.sockets.*
import io.ktor.client.plugins.*
import io.ktor.client.plugins.sse.ClientSSESession
import io.ktor.client.plugins.sse.sseSession
import io.ktor.client.request.*
import io.ktor.client.statement.*
import io.ktor.http.ContentType
import io.ktor.sse.ServerSentEvent
import io.ktor.util.reflect.*
import io.ktor.utils.io.errors.*
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.onEach

/** HTTP transport layer */
internal class HttpTransport(private val httpClient: HttpClient) : HttpRequester {
Expand All @@ -35,6 +46,19 @@ internal class HttpTransport(private val httpClient: HttpClient) : HttpRequester
}
}

/**
 * Opens an SSE session configured by [builderBlock] and maps each incoming [ServerSentEvent]
 * to an [AssistantStreamEvent].
 *
 * Exceptions thrown while setting up the session are routed through `handleException`.
 */
override suspend fun performSse(
    builderBlock: HttpRequestBuilder.() -> Unit
): Flow<AssistantStreamEvent> = try {
    val session = httpClient.sseSession(block = builderBlock)
    // NOTE(review): the mapping below is lazy, so failures raised while the returned flow is
    // collected presumably bypass this try/catch — confirm whether they should also go through
    // handleException. No explicit close of the session is visible here either.
    session.incoming.map { sse -> sse.toAssistantStreamEvent() }
} catch (e: Exception) {
    throw handleException(e)
}

override fun close() {
httpClient.close()
}
Expand Down
Loading

0 comments on commit 04ab4a6

Please sign in to comment.