Skip to content

Commit

Permalink
Merge pull request #5 from llm-agents-php/feature/embedding-generator
Browse files Browse the repository at this point in the history
Adds embedding generator
  • Loading branch information
butschster authored Sep 7, 2024
2 parents 2510748 + ae31687 commit dd22ae4
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 2 deletions.
7 changes: 5 additions & 2 deletions composer.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"require": {
"php": "^8.3",
"openai-php/client": "^0.10.1",
"llm-agents/agents": "^1.4",
"llm-agents/agents": "^1.5",
"guzzlehttp/guzzle": "^7.0"
},
"require-dev": {
Expand All @@ -24,7 +24,10 @@
}
},
"config": {
"sort-packages": true
"sort-packages": true,
"allow-plugins": {
"php-http/discovery": false
}
},
"extra": {
"laravel": {
Expand Down
34 changes: 34 additions & 0 deletions src/Embeddings/EmbeddingGenerator.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
<?php

declare(strict_types=1);

namespace LLM\Agents\OpenAI\Client\Embeddings;

use LLM\Agents\Embeddings\Document;
use LLM\Agents\Embeddings\Embedding;
use LLM\Agents\Embeddings\EmbeddingGeneratorInterface;
use OpenAI\Contracts\ClientContract;

final readonly class EmbeddingGenerator implements EmbeddingGeneratorInterface
{
public function __construct(
private ClientContract $client,
private OpenAIEmbeddingModel $model = OpenAIEmbeddingModel::TextEmbeddingAda002,
) {}

public function generate(Document ...$documents): array
{
$documents = \array_values($documents);

$response = $this->client->embeddings()->create([
'model' => $this->model->value,
'input' => \array_map(static fn(Document $doc): string => $doc->content, $documents),
]);

foreach ($response->embeddings as $i => $embedding) {
$documents[$i] = $documents[$i]->withEmbedding(new Embedding($embedding->embedding));
}

return $documents;
}
}
12 changes: 12 additions & 0 deletions src/Embeddings/OpenAIEmbeddingModel.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
<?php

declare(strict_types=1);

namespace LLM\Agents\OpenAI\Client\Embeddings;

enum OpenAIEmbeddingModel: string
{
case TextEmbedding3Small = 'text-embedding-3-small';
case TextEmbedding3Large = 'text-embedding-3-large';
case TextEmbeddingAda002 = 'text-embedding-ada-002';
}
22 changes: 22 additions & 0 deletions src/Integration/Laravel/OpenAIClientServiceProvider.php
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@

use Illuminate\Contracts\Foundation\Application;
use Illuminate\Support\ServiceProvider;
use LLM\Agents\Embeddings\EmbeddingGeneratorInterface;
use LLM\Agents\LLM\LLMInterface;
use LLM\Agents\OpenAI\Client\Embeddings\EmbeddingGenerator;
use LLM\Agents\OpenAI\Client\Embeddings\OpenAIEmbeddingModel;
use LLM\Agents\OpenAI\Client\LLM;
use LLM\Agents\OpenAI\Client\Parsers\ChatResponseParser;
use LLM\Agents\OpenAI\Client\StreamResponseParser;
use OpenAI\Contracts\ClientContract;
use OpenAI\Responses\Chat\CreateStreamedResponse;

final class OpenAIClientServiceProvider extends ServiceProvider
Expand All @@ -21,6 +25,24 @@ public function register(): void
LLM::class,
);

$this->app->singleton(
EmbeddingGeneratorInterface::class,
EmbeddingGenerator::class,
);

$this->app->singleton(
EmbeddingGenerator::class,
static function (
ClientContract $client,
): EmbeddingGenerator {
return new EmbeddingGenerator(
client: $client,
// todo: use config
model: OpenAIEmbeddingModel::TextEmbeddingAda002,
);
},
);

$this->app->singleton(
StreamResponseParser::class,
static function (Application $app): StreamResponseParser {
Expand Down
16 changes: 16 additions & 0 deletions src/Integration/Spiral/OpenAIClientBootloader.php
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
namespace LLM\Agents\OpenAI\Client\Integration\Spiral;

use GuzzleHttp\Client as HttpClient;
use LLM\Agents\Embeddings\EmbeddingGeneratorInterface;
use LLM\Agents\LLM\LLMInterface;
use LLM\Agents\OpenAI\Client\Embeddings\EmbeddingGenerator;
use LLM\Agents\OpenAI\Client\Embeddings\OpenAIEmbeddingModel;
use LLM\Agents\OpenAI\Client\LLM;
use LLM\Agents\OpenAI\Client\Parsers\ChatResponseParser;
use LLM\Agents\OpenAI\Client\StreamResponseParser;
Expand All @@ -20,6 +23,19 @@ public function defineSingletons(): array
{
return [
LLMInterface::class => LLM::class,
EmbeddingGeneratorInterface::class => EmbeddingGenerator::class,

EmbeddingGenerator::class => static function (
ClientContract $client,
EnvironmentInterface $env,
): EmbeddingGenerator {
return new EmbeddingGenerator(
client: $client,
model: OpenAIEmbeddingModel::from(
$env->get('OPENAI_EMBEDDING_MODEL', OpenAIEmbeddingModel::TextEmbedding3Small->value),
),
);
},

ClientContract::class => static fn(
EnvironmentInterface $env,
Expand Down

0 comments on commit dd22ae4

Please sign in to comment.