Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor/refactor and bug fixes #54

Merged
merged 2 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,12 @@ $env:API_KEY="your_api_key""
`codai` requires a `config.yml` file in the root of your working directory to analyze your project. By default, the `config.yml` contains the following values:
```yml
ai_provider_config:
provider_name: "openai"
provider_name: "openai"
chat_completion_url: "http://localhost:11434/v1/chat/completions"
chat_completion_model: "gpt-4o"
embedding_url: "http://localhost:11434/v1/embeddings" (Optional, If you want use RAG.)
embedding_model: "text-embedding-3-small" (Optional, If you want use RAG.)
temperature: 0.2
threshold: 0.3
theme: "dracula"
rag: true (Optional, If you want use RAG.)
```
Expand Down
Binary file modified assets/codai-demo.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
73 changes: 36 additions & 37 deletions cmd/code.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,15 @@ improved responses throughout the user experience.`,
func handleCodeCommand(rootDependencies *RootDependencies) {

// Create a context with cancel function
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer stop()

//Channel to signal when the application should shut down
done := make(chan bool)
spinner := pterm.DefaultSpinner.WithStyle(pterm.NewStyle(pterm.FgLightBlue)).WithSequence("⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏").WithDelay(100).WithRemoveWhenDone(true)

sigs := make(chan os.Signal, 1)
go utils.GracefulShutdown(ctx, stop, func() {

signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)

go utils.GracefulShutdown(ctx, done, sigs, func() {
rootDependencies.ChatHistory.ClearHistory()
rootDependencies.TokenManagement.ClearToken()
})

reader := bufio.NewReader(os.Stdin)
Expand All @@ -58,39 +55,45 @@ func handleCodeCommand(rootDependencies *RootDependencies) {
codeOptionsBox := lipgloss.BoxStyle.Render(":help Help for code subcommand")
fmt.Println(codeOptionsBox)

spinnerLoadContext, err := pterm.DefaultSpinner.WithStyle(pterm.NewStyle(pterm.FgLightBlue)).WithSequence("⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏").WithDelay(100).Start("Loading Context...")

// Get all data files from the root directory
fullContextFiles, fullContextCodes, err = rootDependencies.Analyzer.GetProjectFiles(rootDependencies.Cwd)

spinnerLoadContext, err := spinner.Start("Loading Context...")
if err != nil {
spinnerLoadContext.Stop()
fmt.Print("\r")
fmt.Println(lipgloss.Red.Render(fmt.Sprintf("%v", err)))
}

// Get all data files from the root directory
fullContextFiles, fullContextCodes, err = rootDependencies.Analyzer.GetProjectFiles(rootDependencies.Cwd)

spinnerLoadContext.Stop()
fmt.Print("\r")

// Launch the user input handler in a goroutine
startLoop: // Label for the start loop
for {
select {
case <-ctx.Done():
<-done // Wait for gracefulShutdown to complete
// Wait for GracefulShutdown to complete
return

default:

displayTokens := func() {
rootDependencies.TokenManagement.DisplayTokens(rootDependencies.Config.AIProviderConfig.ProviderName, rootDependencies.Config.AIProviderConfig.ChatCompletionModel, rootDependencies.Config.AIProviderConfig.EmbeddingModel, rootDependencies.Config.RAG)
}

// Get user input
userInput, err := utils.InputPrompt(reader)

if err != nil {
fmt.Println(lipgloss.Red.Render(fmt.Sprintf("%v", err)))
continue
}

if userInput == "" {
fmt.Print("\r")
continue
}

// Configure help code subcommand
isHelpSubcommands, exit := findCodeSubCommand(userInput, rootDependencies)

Expand All @@ -99,18 +102,17 @@ startLoop: // Label for the start loop
}

if exit {
cancel() // Initiate shutdown for the app's own ":exit" command
<-done // Wait for gracefulShutdown to complete
return
}

// If RAG is enabled, we use RAG system for retrieve most relevant data due user request
if rootDependencies.Config.RAG {

spinnerEmbeddingContext, err := spinner.Start("Embedding Context...")

var wg sync.WaitGroup
errorChan := make(chan error, len(fullContextFiles))

spinnerLoadContextEmbedding, err := pterm.DefaultSpinner.WithStyle(pterm.NewStyle(pterm.FgLightBlue)).WithSequence("⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏").WithDelay(100).Start("Embedding Context...")

for _, dataFile := range fullContextFiles {
wg.Add(1) // Increment the WaitGroup counter
go func(dataFile models.FileData) {
Expand All @@ -127,7 +129,7 @@ startLoop: // Label for the start loop
}

// Call the retryWithBackoff function with the operation and a 3-time retry
if err := utils.RetryWithBackoff(filesEmbeddingOperation, 3); err != nil {
if err := filesEmbeddingOperation(); err != nil {
errorChan <- err // Send the error to the channel
}
}(dataFile) // Pass the current dataFile to the Goroutine
Expand All @@ -137,7 +139,8 @@ startLoop: // Label for the start loop
close(errorChan) // Close the error channel
// Handle any errors that occurred during processing
for err = range errorChan {
spinnerLoadContextEmbedding.Stop()
spinnerEmbeddingContext.Stop()
fmt.Print("\r")
fmt.Println(lipgloss.Red.Render(fmt.Sprintf("%v", err)))
displayTokens()
continue startLoop
Expand All @@ -159,21 +162,20 @@ startLoop: // Label for the start loop
topN := -1

// Step 6: Find relevant code chunks based on the user query embedding
fullContextCodes = rootDependencies.Store.FindRelevantChunks(queryEmbedding[0], topN, rootDependencies.Config.AIProviderConfig.Threshold)
fullContextCodes = rootDependencies.Store.FindRelevantChunks(queryEmbedding[0], topN, rootDependencies.Config.AIProviderConfig.EmbeddingModel, rootDependencies.Config.AIProviderConfig.Threshold)
return nil
}

// Call the retryWithBackoff function with the operation and a 3 time retry
err = utils.RetryWithBackoff(queryEmbeddingOperation, 3)

if err != nil {
spinnerLoadContextEmbedding.Stop()
if err := queryEmbeddingOperation(); err != nil {
spinnerEmbeddingContext.Stop()
fmt.Print("\r")
fmt.Println(lipgloss.Red.Render(fmt.Sprintf("%v", err)))
displayTokens()
continue startLoop
}

spinnerLoadContextEmbedding.Stop()
spinnerEmbeddingContext.Stop()
fmt.Print("\r")
}

var aiResponseBuilder strings.Builder
Expand Down Expand Up @@ -206,10 +208,7 @@ startLoop: // Label for the start loop
return nil
}

// Call the retryWithBackoff function with the operation and a 3 time retry
err = utils.RetryWithBackoff(chatRequestOperation, 3)

if err != nil {
if err := chatRequestOperation(); err != nil {
fmt.Println(lipgloss.Red.Render(fmt.Sprintf("%v", err)))
displayTokens()
continue startLoop
Expand All @@ -224,9 +223,7 @@ startLoop: // Label for the start loop

fmt.Println(lipgloss.BlueSky.Render("\nThese files need to changes...\n"))

err = chatRequestOperation()

if err != nil {
if err := chatRequestOperation(); err != nil {
fmt.Println(lipgloss.Red.Render(fmt.Sprintf("%v", err)))
displayTokens()
continue
Expand Down Expand Up @@ -275,14 +272,16 @@ startLoop: // Label for the start loop
// If we need Update the context after apply changes
if updateContextNeeded {

spinnerUpdateContext, err := pterm.DefaultSpinner.WithStyle(pterm.NewStyle(pterm.FgLightBlue)).WithSequence("⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏").WithDelay(100).Start("Updating Context...")
spinnerUpdateContext, err := spinner.Start("Updating Context...")

fullContextFiles, fullContextCodes, err = rootDependencies.Analyzer.GetProjectFiles(rootDependencies.Cwd)
if err != nil {
spinnerUpdateContext.Stop()
fmt.Print("\r")
fmt.Println(lipgloss.Red.Render(fmt.Sprintf("%v", err)))
}
spinnerUpdateContext.Stop()
fmt.Print("\r")
}
displayTokens()
}
Expand All @@ -300,7 +299,7 @@ func findCodeSubCommand(command string, rootDependencies *RootDependencies) (boo
fmt.Print("\033[2J\033[H")
return true, false
case ":exit":
return false, false
return false, true
case ":token":
rootDependencies.TokenManagement.DisplayTokens(
rootDependencies.Config.AIProviderConfig.ProviderName,
Expand Down
12 changes: 9 additions & 3 deletions code_analyzer/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,16 @@ func (analyzer *CodeAnalyzer) GetProjectFiles(rootDir string) ([]models.FileData
// Ensure that the current entry is a file, not a directory
if !d.IsDir() {

// Check file size
fileInfo, err := os.Stat(path)
if err != nil {
return err
return fmt.Errorf("failed to get file info: %s, error: %w", relativePath, err)
}
// Skip files over 100 KB (100 * 1024 bytes)
if fileInfo.Size() > 100*1024 {
return nil // Skip this file
}

if utils.IsGitIgnored(relativePath, gitIgnorePatterns) {
// Debugging: Print the ignored file
return nil // Skip this file
Expand Down Expand Up @@ -245,7 +252,7 @@ func (analyzer *CodeAnalyzer) TryGetInCompletedCodeBlocK(relativePaths string) (
}

func (analyzer *CodeAnalyzer) ExtractCodeChanges(diff string) []models.CodeChange {
filePathPattern := regexp.MustCompile(`(?i)(?:\d+\.\s*|File:\s*)(\S+\.[a-zA-Z0-9]+)`)
filePathPattern := regexp.MustCompile("(?i)(?:\\d+\\.\\s*|File:\\s*)[`']?([^\\s*`']+?\\.[a-zA-Z0-9]+)[`']?\\b")

lines := strings.Split(diff, "\n")
var fileChanges []models.CodeChange
Expand Down Expand Up @@ -333,7 +340,6 @@ func (analyzer *CodeAnalyzer) ApplyChanges(relativePath, diff string) error {
} else if strings.HasPrefix(trimmedLine, "+") {
// Add lines that start with "+", but remove the "+" symbol
updatedContent = append(updatedContent, strings.ReplaceAll(trimmedLine, "+", " "))

} else {
// Keep all other lines as they are
updatedContent = append(updatedContent, line)
Expand Down
10 changes: 0 additions & 10 deletions code_analyzer/analyzer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,16 +218,6 @@ func TestExtractCodeChangesComplexText(t *testing.T) {
assert.Equal(t, "import pygame\nimport random\n\n# Initialize pygame\npygame.init()\n\n# Screen dimensions\nSCREEN_WIDTH = 800\nSCREEN_HEIGHT = 600\n\n# Colors\nBLACK = (0, 0, 0)\nWHITE = (255, 255, 255)\nYELLOW = (255, 255, 0)\nRED = (255, 0, 0)\n\n# Pacman settings\nPACMAN_SIZE = 50\nPACMAN_SPEED = 5\n\n# Ghost settings\nGHOST_SIZE = 50\nGHOST_SPEED = 3\n\n# Create the screen\nscreen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))\npygame.display.set_caption(\"Pacman Game\")\n\n# Load images\npacman_image = pygame.image.load(\"pacman.png\")\npacman_image = pygame.transform.scale(pacman_image, (PACMAN_SIZE, PACMAN_SIZE))\n\nghost_image = pygame.image.load(\"ghost.png\")\nghost_image = pygame.transform.scale(ghost_image, (GHOST_SIZE, GHOST_SIZE))\n\n# Pacman class\nclass Pacman:\n def __init__(self):\n self.x = SCREEN_WIDTH // 2\n self.y = SCREEN_HEIGHT // 2\n self.speed = PACMAN_SPEED\n self.image = pacman_image\n\n def move(self, dx, dy):\n self.x += dx * self.speed\n self.y += dy * self.speed\n\n # Boundary check\n if self.x < 0:\n self.x = 0\n elif self.x > SCREEN_WIDTH - PACMAN_SIZE:\n self.x = SCREEN_WIDTH - PACMAN_SIZE\n\n if self.y < 0:\n self.y = 0\n elif self.y > SCREEN_HEIGHT - PACMAN_SIZE:\n self.y = SCREEN_HEIGHT - PACMAN_SIZE\n\n def draw(self):\n screen.blit(self.image, (self.x, self.y))\n\n# Ghost class\nclass Ghost:\n def __init__(self):\n self.x = random.randint(0, SCREEN_WIDTH - GHOST_SIZE)\n self.y = random.randint(0, SCREEN_HEIGHT - GHOST_SIZE)\n self.speed = GHOST_SPEED\n self.image = ghost_image\n\n def move(self):\n self.x += random.choice([-1, 1]) * self.speed\n self.y += random.choice([-1, 1]) * self.speed\n\n # Boundary check\n if self.x < 0:\n self.x = 0\n elif self.x > SCREEN_WIDTH - GHOST_SIZE:\n self.x = SCREEN_WIDTH - GHOST_SIZE\n\n if self.y < 0:\n self.y = 0\n elif self.y > SCREEN_HEIGHT - GHOST_SIZE:\n self.y = SCREEN_HEIGHT - GHOST_SIZE\n\n def draw(self):\n screen.blit(self.image, (self.x, self.y))\n\n# Main game loop\ndef main():\n clock = pygame.time.Clock()\n pacman = Pacman()\n ghosts = [Ghost() for _ in range(4)]\n\n running = True\n while running:\n for event in pygame.event.get():\n if event.type == pygame.QUIT:\n running = False\n\n keys = pygame.key.get_pressed()\n dx = dy = 0\n if keys[pygame.K_LEFT]:\n dx = -1\n if keys[pygame.K_RIGHT]:\n dx = 1\n if keys[pygame.K_UP]:\n dy = -1\n if keys[pygame.K_DOWN]:\n dy = 1\n\n pacman.move(dx, dy)\n\n for ghost in ghosts:\n ghost.move()\n\n screen.fill(BLACK)\n pacman.draw()\n for ghost in ghosts:\n ghost.draw()\n\n pygame.display.flip()\n clock.tick(30)\n\n pygame.quit()\n\nif __name__ == \"__main__\":\n main()", codeChanges[0].Code)
}

// / Test for ExtractCodeChanges with standard input
func TestExtractCodeChangesComplexText2(t *testing.T) {
setup(t)
text := "Sure, I can help you create a simple \"Dino Game\" using Python and the Pygame library. This game will be similar to the Chrome Dino game where a dinosaur jumps over obstacles.\n\nFirst, ensure you have Pygame installed. You can install it using pip:\n```sh\npip install pygame\n```\n\nNow, let's create the game. We will need a few files:\n1. `main.py` - The main game logic.\n2. `dino.py` - The Dino character logic.\n3. `obstacle.py` - The obstacle logic.\n4. `settings.py` - Game settings and configurations.\n\n### File: `settings.py`\n```python\n# Game settings and configurations\n\nSCREEN_WIDTH = 800\nSCREEN_HEIGHT = 400\nFPS = 30\n\n# Colors\nWHITE = (255, 255, 255)\nBLACK = (0, 0, 0)\n\n# Dino settings\nDINO_WIDTH = 40\nDINO_HEIGHT = 60\nDINO_COLOR = (0, 255, 0)\nDINO_JUMP_HEIGHT = 10\n\n# Obstacle settings\nOBSTACLE_WIDTH = 20\nOBSTACLE_HEIGHT = 40\nOBSTACLE_COLOR = (255, 0, 0)\nOBSTACLE_SPEED = 5\n```\n\n### File: `dino.py`\n```python\nimport pygame\nfrom settings import DINO_WIDTH, DINO_HEIGHT, DINO_COLOR, DINO_JUMP_HEIGHT\n\nclass Dino:\n def __init__(self, x, y):\n self.rect = pygame.Rect(x, y, DINO_WIDTH, DINO_HEIGHT)\n self.color = DINO_COLOR\n self.is_jumping = False\n self.jump_speed = DINO_JUMP_HEIGHT\n self.gravity = 1\n\n def update(self):\n if self.is_jumping:\n self.rect.y -= self.jump_speed\n self.jump_speed -= self.gravity\n if self.jump_speed < -DINO_JUMP_HEIGHT:\n self.is_jumping = False\n self.jump_speed = DINO_JUMP_HEIGHT\n\n def draw(self, screen):\n pygame.draw.rect(screen, self.color, self.rect)\n\n def jump(self):\n if not self.is_jumping:\n self.is_jumping = True\n```\n\n### File: `obstacle.py`\n```python\nimport pygame\nfrom settings import OBSTACLE_WIDTH, OBSTACLE_HEIGHT, OBSTACLE_COLOR, OBSTACLE_SPEED\n\nclass Obstacle:\n def __init__(self, x, y):\n self.rect = pygame.Rect(x, y, OBSTACLE_WIDTH, OBSTACLE_HEIGHT)\n self.color = OBSTACLE_COLOR\n self.speed = OBSTACLE_SPEED\n\n def update(self):\n self.rect.x -= self.speed\n\n def draw(self, screen):\n pygame.draw.rect(screen, self.color, self.rect)\n```\n\n### File: `main.py`\n```python\nimport pygame\nimport random\nfrom settings import SCREEN_WIDTH, SCREEN_HEIGHT, FPS, WHITE\nfrom dino.py import Dino\nfrom obstacle.py import Obstacle\n\ndef main():\n pygame.init()\n screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))\n pygame.display.set_caption(\"Dino Game\")\n clock = pygame.time.Clock()\n\n dino = Dino(50, SCREEN_HEIGHT - 60)\n obstacles = []\n\n running = True\n while running:\n for event in pygame.event.get():\n if event.type == pygame.QUIT:\n running = False\n if event.type == pygame.KEYDOWN:\n if event.key == pygame.K_SPACE:\n dino.jump()\n\n screen.fill(WHITE)\n\n dino.update()\n dino.draw(screen)\n\n if random.randint(1, 100) < 2:\n obstacles.append(Obstacle(SCREEN_WIDTH, SCREEN_HEIGHT - 40))\n\n for obstacle in obstacles[:]:\n obstacle.update()\n obstacle.draw(screen)\n if obstacle.rect.x < 0:\n obstacles.remove(obstacle)\n if dino.rect.colliderect(obstacle.rect):\n running = False\n\n pygame.display.flip()\n clock.tick(FPS)\n\n pygame.quit()\n\nif __name__ == \"__main__\":\n main()\n```\n\nThis code sets up a basic Dino game where the dinosaur can jump over obstacles. The game will end if the dinosaur collides with an obstacle. You can expand and improve this game by adding more features, such as scoring, different types of obstacles, and animations."

codeChanges := analyzer.ExtractCodeChanges(text)

assert.Len(t, codeChanges, 4)
}

// / Test for ExtractCodeChanges with standard input
func TestExtractCodeChanges(t *testing.T) {
setup(t)
Expand Down
2 changes: 1 addition & 1 deletion config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ var defaultConfig = Config{
Stream: true,
EncodingFormat: "float",
Temperature: 0.2,
Threshold: 0.3,
Threshold: 0,
ApiKey: "",
},
}
Expand Down
Loading
Loading