diff --git a/README.md b/README.md index 3392c5a..5b02eac 100644 --- a/README.md +++ b/README.md @@ -50,13 +50,12 @@ $env:API_KEY="your_api_key"" `codai` requires a `config.yml` file in the root of your working directory to analyze your project. By default, the `config.yml` contains the following values: ```yml ai_provider_config: - provider_name: "openai" + provider_name: "openai" chat_completion_url: "http://localhost:11434/v1/chat/completions" chat_completion_model: "gpt-4o" embedding_url: "http://localhost:11434/v1/embeddings" (Optional, If you want use RAG.) embedding_model: "text-embedding-3-small" (Optional, If you want use RAG.) temperature: 0.2 - threshold: 0.3 theme: "dracula" rag: true (Optional, If you want use RAG.) ``` diff --git a/assets/codai-demo.gif b/assets/codai-demo.gif index f1c4945..800491e 100644 Binary files a/assets/codai-demo.gif and b/assets/codai-demo.gif differ diff --git a/cmd/code.go b/cmd/code.go index 1c82aee..4ba65c9 100644 --- a/cmd/code.go +++ b/cmd/code.go @@ -33,18 +33,15 @@ improved responses throughout the user experience.`, func handleCodeCommand(rootDependencies *RootDependencies) { // Create a context with cancel function - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer stop() - //Channel to signal when the application should shut down - done := make(chan bool) + spinner := pterm.DefaultSpinner.WithStyle(pterm.NewStyle(pterm.FgLightBlue)).WithSequence("⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏").WithDelay(100).WithRemoveWhenDone(true) - sigs := make(chan os.Signal, 1) + go utils.GracefulShutdown(ctx, stop, func() { - signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) - - go utils.GracefulShutdown(ctx, done, sigs, func() { rootDependencies.ChatHistory.ClearHistory() + rootDependencies.TokenManagement.ClearToken() }) reader := bufio.NewReader(os.Stdin) @@ -58,39 +55,45 @@ func 
handleCodeCommand(rootDependencies *RootDependencies) { codeOptionsBox := lipgloss.BoxStyle.Render(":help Help for code subcommand") fmt.Println(codeOptionsBox) - spinnerLoadContext, err := pterm.DefaultSpinner.WithStyle(pterm.NewStyle(pterm.FgLightBlue)).WithSequence("⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏").WithDelay(100).Start("Loading Context...") - - // Get all data files from the root directory - fullContextFiles, fullContextCodes, err = rootDependencies.Analyzer.GetProjectFiles(rootDependencies.Cwd) - + spinnerLoadContext, err := spinner.Start("Loading Context...") if err != nil { spinnerLoadContext.Stop() + fmt.Print("\r") fmt.Println(lipgloss.Red.Render(fmt.Sprintf("%v", err))) } + // Get all data files from the root directory + fullContextFiles, fullContextCodes, err = rootDependencies.Analyzer.GetProjectFiles(rootDependencies.Cwd) + spinnerLoadContext.Stop() + fmt.Print("\r") // Launch the user input handler in a goroutine startLoop: // Label for the start loop for { select { case <-ctx.Done(): - <-done // Wait for gracefulShutdown to complete + // Wait for GracefulShutdown to complete return default: - displayTokens := func() { rootDependencies.TokenManagement.DisplayTokens(rootDependencies.Config.AIProviderConfig.ProviderName, rootDependencies.Config.AIProviderConfig.ChatCompletionModel, rootDependencies.Config.AIProviderConfig.EmbeddingModel, rootDependencies.Config.RAG) } // Get user input userInput, err := utils.InputPrompt(reader) + if err != nil { fmt.Println(lipgloss.Red.Render(fmt.Sprintf("%v", err))) continue } + if userInput == "" { + fmt.Print("\r") + continue + } + // Configure help code subcommand isHelpSubcommands, exit := findCodeSubCommand(userInput, rootDependencies) @@ -99,18 +102,17 @@ startLoop: // Label for the start loop } if exit { - cancel() // Initiate shutdown for the app's own ":exit" command - <-done // Wait for gracefulShutdown to complete return } // If RAG is enabled, we use RAG system for retrieve most relevant 
data due user request if rootDependencies.Config.RAG { + + spinnerEmbeddingContext, err := spinner.Start("Embedding Context...") + var wg sync.WaitGroup errorChan := make(chan error, len(fullContextFiles)) - spinnerLoadContextEmbedding, err := pterm.DefaultSpinner.WithStyle(pterm.NewStyle(pterm.FgLightBlue)).WithSequence("⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏").WithDelay(100).Start("Embedding Context...") - for _, dataFile := range fullContextFiles { wg.Add(1) // Increment the WaitGroup counter go func(dataFile models.FileData) { @@ -127,7 +129,7 @@ startLoop: // Label for the start loop } // Call the retryWithBackoff function with the operation and a 3-time retry - if err := utils.RetryWithBackoff(filesEmbeddingOperation, 3); err != nil { + if err := filesEmbeddingOperation(); err != nil { errorChan <- err // Send the error to the channel } }(dataFile) // Pass the current dataFile to the Goroutine @@ -137,7 +139,8 @@ startLoop: // Label for the start loop close(errorChan) // Close the error channel // Handle any errors that occurred during processing for err = range errorChan { - spinnerLoadContextEmbedding.Stop() + spinnerEmbeddingContext.Stop() + fmt.Print("\r") fmt.Println(lipgloss.Red.Render(fmt.Sprintf("%v", err))) displayTokens() continue startLoop @@ -159,21 +162,20 @@ startLoop: // Label for the start loop topN := -1 // Step 6: Find relevant code chunks based on the user query embedding - fullContextCodes = rootDependencies.Store.FindRelevantChunks(queryEmbedding[0], topN, rootDependencies.Config.AIProviderConfig.Threshold) + fullContextCodes = rootDependencies.Store.FindRelevantChunks(queryEmbedding[0], topN, rootDependencies.Config.AIProviderConfig.EmbeddingModel, rootDependencies.Config.AIProviderConfig.Threshold) return nil } - // Call the retryWithBackoff function with the operation and a 3 time retry - err = utils.RetryWithBackoff(queryEmbeddingOperation, 3) - - if err != nil { - spinnerLoadContextEmbedding.Stop() + if err := 
queryEmbeddingOperation(); err != nil { + spinnerEmbeddingContext.Stop() + fmt.Print("\r") fmt.Println(lipgloss.Red.Render(fmt.Sprintf("%v", err))) displayTokens() continue startLoop } - spinnerLoadContextEmbedding.Stop() + spinnerEmbeddingContext.Stop() + fmt.Print("\r") } var aiResponseBuilder strings.Builder @@ -206,10 +208,7 @@ startLoop: // Label for the start loop return nil } - // Call the retryWithBackoff function with the operation and a 3 time retry - err = utils.RetryWithBackoff(chatRequestOperation, 3) - - if err != nil { + if err := chatRequestOperation(); err != nil { fmt.Println(lipgloss.Red.Render(fmt.Sprintf("%v", err))) displayTokens() continue startLoop @@ -224,9 +223,7 @@ startLoop: // Label for the start loop fmt.Println(lipgloss.BlueSky.Render("\nThese files need to changes...\n")) - err = chatRequestOperation() - - if err != nil { + if err := chatRequestOperation(); err != nil { fmt.Println(lipgloss.Red.Render(fmt.Sprintf("%v", err))) displayTokens() continue @@ -275,14 +272,16 @@ startLoop: // Label for the start loop // If we need Update the context after apply changes if updateContextNeeded { - spinnerUpdateContext, err := pterm.DefaultSpinner.WithStyle(pterm.NewStyle(pterm.FgLightBlue)).WithSequence("⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏").WithDelay(100).Start("Updating Context...") + spinnerUpdateContext, err := spinner.Start("Updating Context...") fullContextFiles, fullContextCodes, err = rootDependencies.Analyzer.GetProjectFiles(rootDependencies.Cwd) if err != nil { spinnerUpdateContext.Stop() + fmt.Print("\r") fmt.Println(lipgloss.Red.Render(fmt.Sprintf("%v", err))) } spinnerUpdateContext.Stop() + fmt.Print("\r") } displayTokens() } @@ -300,7 +299,7 @@ func findCodeSubCommand(command string, rootDependencies *RootDependencies) (boo fmt.Print("\033[2J\033[H") return true, false case ":exit": - return false, false + return false, true case ":token": rootDependencies.TokenManagement.DisplayTokens( 
rootDependencies.Config.AIProviderConfig.ProviderName, diff --git a/code_analyzer/analyzer.go b/code_analyzer/analyzer.go index e276ebb..8f0051b 100644 --- a/code_analyzer/analyzer.go +++ b/code_analyzer/analyzer.go @@ -91,9 +91,16 @@ func (analyzer *CodeAnalyzer) GetProjectFiles(rootDir string) ([]models.FileData // Ensure that the current entry is a file, not a directory if !d.IsDir() { + // Check file size + fileInfo, err := os.Stat(path) if err != nil { - return err + return fmt.Errorf("failed to get file info: %s, error: %w", relativePath, err) } + // Skip files over 100 KB (100 * 1024 bytes) + if fileInfo.Size() > 100*1024 { + return nil // Skip this file + } + if utils.IsGitIgnored(relativePath, gitIgnorePatterns) { // Debugging: Print the ignored file return nil // Skip this file @@ -245,7 +252,7 @@ func (analyzer *CodeAnalyzer) TryGetInCompletedCodeBlocK(relativePaths string) ( } func (analyzer *CodeAnalyzer) ExtractCodeChanges(diff string) []models.CodeChange { - filePathPattern := regexp.MustCompile(`(?i)(?:\d+\.\s*|File:\s*)(\S+\.[a-zA-Z0-9]+)`) + filePathPattern := regexp.MustCompile("(?i)(?:\\d+\\.\\s*|File:\\s*)[`']?([^\\s*`']+?\\.[a-zA-Z0-9]+)[`']?\\b") lines := strings.Split(diff, "\n") var fileChanges []models.CodeChange @@ -333,7 +340,6 @@ func (analyzer *CodeAnalyzer) ApplyChanges(relativePath, diff string) error { } else if strings.HasPrefix(trimmedLine, "+") { // Add lines that start with "+", but remove the "+" symbol updatedContent = append(updatedContent, strings.ReplaceAll(trimmedLine, "+", " ")) - } else { // Keep all other lines as they are updatedContent = append(updatedContent, line) diff --git a/code_analyzer/analyzer_test.go b/code_analyzer/analyzer_test.go index c3a35cf..89e2aa5 100644 --- a/code_analyzer/analyzer_test.go +++ b/code_analyzer/analyzer_test.go @@ -218,16 +218,6 @@ func TestExtractCodeChangesComplexText(t *testing.T) { assert.Equal(t, "import pygame\nimport random\n\n# Initialize pygame\npygame.init()\n\n# Screen 
dimensions\nSCREEN_WIDTH = 800\nSCREEN_HEIGHT = 600\n\n# Colors\nBLACK = (0, 0, 0)\nWHITE = (255, 255, 255)\nYELLOW = (255, 255, 0)\nRED = (255, 0, 0)\n\n# Pacman settings\nPACMAN_SIZE = 50\nPACMAN_SPEED = 5\n\n# Ghost settings\nGHOST_SIZE = 50\nGHOST_SPEED = 3\n\n# Create the screen\nscreen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))\npygame.display.set_caption(\"Pacman Game\")\n\n# Load images\npacman_image = pygame.image.load(\"pacman.png\")\npacman_image = pygame.transform.scale(pacman_image, (PACMAN_SIZE, PACMAN_SIZE))\n\nghost_image = pygame.image.load(\"ghost.png\")\nghost_image = pygame.transform.scale(ghost_image, (GHOST_SIZE, GHOST_SIZE))\n\n# Pacman class\nclass Pacman:\n def __init__(self):\n self.x = SCREEN_WIDTH // 2\n self.y = SCREEN_HEIGHT // 2\n self.speed = PACMAN_SPEED\n self.image = pacman_image\n\n def move(self, dx, dy):\n self.x += dx * self.speed\n self.y += dy * self.speed\n\n # Boundary check\n if self.x < 0:\n self.x = 0\n elif self.x > SCREEN_WIDTH - PACMAN_SIZE:\n self.x = SCREEN_WIDTH - PACMAN_SIZE\n\n if self.y < 0:\n self.y = 0\n elif self.y > SCREEN_HEIGHT - PACMAN_SIZE:\n self.y = SCREEN_HEIGHT - PACMAN_SIZE\n\n def draw(self):\n screen.blit(self.image, (self.x, self.y))\n\n# Ghost class\nclass Ghost:\n def __init__(self):\n self.x = random.randint(0, SCREEN_WIDTH - GHOST_SIZE)\n self.y = random.randint(0, SCREEN_HEIGHT - GHOST_SIZE)\n self.speed = GHOST_SPEED\n self.image = ghost_image\n\n def move(self):\n self.x += random.choice([-1, 1]) * self.speed\n self.y += random.choice([-1, 1]) * self.speed\n\n # Boundary check\n if self.x < 0:\n self.x = 0\n elif self.x > SCREEN_WIDTH - GHOST_SIZE:\n self.x = SCREEN_WIDTH - GHOST_SIZE\n\n if self.y < 0:\n self.y = 0\n elif self.y > SCREEN_HEIGHT - GHOST_SIZE:\n self.y = SCREEN_HEIGHT - GHOST_SIZE\n\n def draw(self):\n screen.blit(self.image, (self.x, self.y))\n\n# Main game loop\ndef main():\n clock = pygame.time.Clock()\n pacman = Pacman()\n ghosts = [Ghost() for _ in 
range(4)]\n\n running = True\n while running:\n for event in pygame.event.get():\n if event.type == pygame.QUIT:\n running = False\n\n keys = pygame.key.get_pressed()\n dx = dy = 0\n if keys[pygame.K_LEFT]:\n dx = -1\n if keys[pygame.K_RIGHT]:\n dx = 1\n if keys[pygame.K_UP]:\n dy = -1\n if keys[pygame.K_DOWN]:\n dy = 1\n\n pacman.move(dx, dy)\n\n for ghost in ghosts:\n ghost.move()\n\n screen.fill(BLACK)\n pacman.draw()\n for ghost in ghosts:\n ghost.draw()\n\n pygame.display.flip()\n clock.tick(30)\n\n pygame.quit()\n\nif __name__ == \"__main__\":\n main()", codeChanges[0].Code) } -// / Test for ExtractCodeChanges with standard input -func TestExtractCodeChangesComplexText2(t *testing.T) { - setup(t) - text := "Sure, I can help you create a simple \"Dino Game\" using Python and the Pygame library. This game will be similar to the Chrome Dino game where a dinosaur jumps over obstacles.\n\nFirst, ensure you have Pygame installed. You can install it using pip:\n```sh\npip install pygame\n```\n\nNow, let's create the game. We will need a few files:\n1. `main.py` - The main game logic.\n2. `dino.py` - The Dino character logic.\n3. `obstacle.py` - The obstacle logic.\n4. 
`settings.py` - Game settings and configurations.\n\n### File: `settings.py`\n```python\n# Game settings and configurations\n\nSCREEN_WIDTH = 800\nSCREEN_HEIGHT = 400\nFPS = 30\n\n# Colors\nWHITE = (255, 255, 255)\nBLACK = (0, 0, 0)\n\n# Dino settings\nDINO_WIDTH = 40\nDINO_HEIGHT = 60\nDINO_COLOR = (0, 255, 0)\nDINO_JUMP_HEIGHT = 10\n\n# Obstacle settings\nOBSTACLE_WIDTH = 20\nOBSTACLE_HEIGHT = 40\nOBSTACLE_COLOR = (255, 0, 0)\nOBSTACLE_SPEED = 5\n```\n\n### File: `dino.py`\n```python\nimport pygame\nfrom settings import DINO_WIDTH, DINO_HEIGHT, DINO_COLOR, DINO_JUMP_HEIGHT\n\nclass Dino:\n def __init__(self, x, y):\n self.rect = pygame.Rect(x, y, DINO_WIDTH, DINO_HEIGHT)\n self.color = DINO_COLOR\n self.is_jumping = False\n self.jump_speed = DINO_JUMP_HEIGHT\n self.gravity = 1\n\n def update(self):\n if self.is_jumping:\n self.rect.y -= self.jump_speed\n self.jump_speed -= self.gravity\n if self.jump_speed < -DINO_JUMP_HEIGHT:\n self.is_jumping = False\n self.jump_speed = DINO_JUMP_HEIGHT\n\n def draw(self, screen):\n pygame.draw.rect(screen, self.color, self.rect)\n\n def jump(self):\n if not self.is_jumping:\n self.is_jumping = True\n```\n\n### File: `obstacle.py`\n```python\nimport pygame\nfrom settings import OBSTACLE_WIDTH, OBSTACLE_HEIGHT, OBSTACLE_COLOR, OBSTACLE_SPEED\n\nclass Obstacle:\n def __init__(self, x, y):\n self.rect = pygame.Rect(x, y, OBSTACLE_WIDTH, OBSTACLE_HEIGHT)\n self.color = OBSTACLE_COLOR\n self.speed = OBSTACLE_SPEED\n\n def update(self):\n self.rect.x -= self.speed\n\n def draw(self, screen):\n pygame.draw.rect(screen, self.color, self.rect)\n```\n\n### File: `main.py`\n```python\nimport pygame\nimport random\nfrom settings import SCREEN_WIDTH, SCREEN_HEIGHT, FPS, WHITE\nfrom dino.py import Dino\nfrom obstacle.py import Obstacle\n\ndef main():\n pygame.init()\n screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))\n pygame.display.set_caption(\"Dino Game\")\n clock = pygame.time.Clock()\n\n dino = Dino(50, SCREEN_HEIGHT - 
60)\n obstacles = []\n\n running = True\n while running:\n for event in pygame.event.get():\n if event.type == pygame.QUIT:\n running = False\n if event.type == pygame.KEYDOWN:\n if event.key == pygame.K_SPACE:\n dino.jump()\n\n screen.fill(WHITE)\n\n dino.update()\n dino.draw(screen)\n\n if random.randint(1, 100) < 2:\n obstacles.append(Obstacle(SCREEN_WIDTH, SCREEN_HEIGHT - 40))\n\n for obstacle in obstacles[:]:\n obstacle.update()\n obstacle.draw(screen)\n if obstacle.rect.x < 0:\n obstacles.remove(obstacle)\n if dino.rect.colliderect(obstacle.rect):\n running = False\n\n pygame.display.flip()\n clock.tick(FPS)\n\n pygame.quit()\n\nif __name__ == \"__main__\":\n main()\n```\n\nThis code sets up a basic Dino game where the dinosaur can jump over obstacles. The game will end if the dinosaur collides with an obstacle. You can expand and improve this game by adding more features, such as scoring, different types of obstacles, and animations." - - codeChanges := analyzer.ExtractCodeChanges(text) - - assert.Len(t, codeChanges, 4) -} - // / Test for ExtractCodeChanges with standard input func TestExtractCodeChanges(t *testing.T) { setup(t) diff --git a/config/config.go b/config/config.go index b6eed0b..6117611 100644 --- a/config/config.go +++ b/config/config.go @@ -31,7 +31,7 @@ var defaultConfig = Config{ Stream: true, EncodingFormat: "float", Temperature: 0.2, - Threshold: 0.3, + Threshold: 0, ApiKey: "", }, } diff --git a/embed_data/prompts/rag_context_prompt.tmpl b/embed_data/prompts/rag_context_prompt.tmpl index 3b1aae1..7e55340 100644 --- a/embed_data/prompts/rag_context_prompt.tmpl +++ b/embed_data/prompts/rag_context_prompt.tmpl @@ -13,50 +13,35 @@ ## General Instructions for Code Modifications: - - If you **modify** or **add** lines in code, follow the exact format for the **Code BLOCK**: - **First line**: the **file name** with **relative path**; no extra markup, punctuation, comments, etc. 
**JUST** the **file name** with **relative path** and **file name** should using **naming conversion** base on **language**. - - **Second line**: Start of the markdown highlighted **Code BLOCK**. + - **Second line**: Start of the **diff BLOCK**. - **All subsequent lines**: The modified code, with specific prefixes based on the change: - Prefix **"+"** for added lines. - Prefix **"-"** for removed lines. - **Unchanged lines** should **not** be prefixed; they should remain as they are. - - **Last line**: End of the markdown highlighted **Code BLOCK**. - - If you provide a **diff** always give me **full code** and **do not summarize that** only for **modify part**, always I want **full code**. - - Always add **relative path** and **file name** **top** of each **Code BLOCK**. + - **Last line**: End of the **diff BLOCK**. + - Always add **relative path** and **file name** **top** of each **diff BLOCK**. + - If **add** a new line, always **must** use prefix **"+"**. + - If **remove** a line, always **must** use prefix **"-"**. + - If you **modify** or **fix** or **refactor** each line **must** use **"+"** or **"-"** and for **unchanged line** leave it as it is. + - **Do not** forget put prefix **"+"** and **"-"** during **modification** lines. - -## Code BLOCK Format: - - Every code modification **must follow** this pattern strictly, ensuring that added, removed, and unchanged lines are marked correctly with **"+"**, **"-"**, or skipped for unchanged lines and keep in mind **give me all of unchanged line and dont summarize them.** - - **Example format**: +## **diff BLOCK** Code Format: File: relativePath/fileName.ext -```diff - package main +```if change is a modification use **diff**, otherwise if it's a **new file** use **language of code** here. 
+ package main - import "fmt" + import "fmt" + import "time" - func main() { + func main() { - fmt.Println("Hello, World!") + fmt.Println("Welcome to Go programming!") + fmt.Println("Current time:", time.Now()) - fmt.Println("This is another unchanged line") - } + fmt.Println("This is another unchanged line") + } ``` - - The **added lines** would have the **"+"** prefix. - - The **removed lines** would have the **"-"** prefix. - - The **unchanged lines** should **remain as is** without any prefix. - -## For Added Lines: - - Any new lines or methods that are added **must be prefixed with a plus sign (+)**. - - If you add methods or classes, mark them as added. - -## For Removed Lines: - - Any lines that are deleted **must be prefixed with a minus sign (-)**. - -## For Unchanged Lines: - - Always give me unchanged line **without summarize** them **exactly like as they were before**. - - Any line of code that remains unchanged **should not be prefixed and should remain as is**. ## Important: - - Under no circumstances, if the Code BLOCK is empty or Code BLOCK is incomplete, do **not** include placeholder comments like "// REST OF THE CODE" or "// IMPLEMENTATION OF....". \ No newline at end of file + - Under no circumstances, if the some part of **body** or **block** is **empty** or **incomplete**, do **not** include placeholder comments like "// REST OF THE CODE" or "// IMPLEMENTATION OF....". \ No newline at end of file diff --git a/embed_data/prompts/summarize_full_context_prompt.tmpl b/embed_data/prompts/summarize_full_context_prompt.tmpl index bdd27c5..b03044b 100644 --- a/embed_data/prompts/summarize_full_context_prompt.tmpl +++ b/embed_data/prompts/summarize_full_context_prompt.tmpl @@ -33,50 +33,37 @@ - **Skip all further instructions**, including any additional processing or explanation. 
+ ## General Instructions for Code Modifications: - - If you **modify** or **add** lines in code, follow the exact format for the **Code BLOCK**: - **First line**: the **file name** with **relative path**; no extra markup, punctuation, comments, etc. **JUST** the **file name** with **relative path** and **file name** should using **naming conversion** base on **language**. - - **Second line**: Start of the markdown highlighted **Code BLOCK**. + - **Second line**: Start of the **diff BLOCK**. - **All subsequent lines**: The modified code, with specific prefixes based on the change: - Prefix **"+"** for added lines. - Prefix **"-"** for removed lines. - **Unchanged lines** should **not** be prefixed; they should remain as they are. - - **Last line**: End of the markdown highlighted **Code BLOCK**. - - If you provide a **diff** always give me **full code** and **do not summarize that** only for **modify part**, always I want **full code**. - - Always add **relative path** and **file name** **top** of each **Code BLOCK**. + - **Last line**: End of the **diff BLOCK**. + - Always add **relative path** and **file name** **top** of each **diff BLOCK**. + - If **add** a new line, always **must** use prefix **"+"**. + - If **remove** a line, always **must** use prefix **"-"**. + - If you **modify** or **fix** or **refactor** each line **must** use **"+"** or **"-"** and for **unchanged line** leave it as it is. + - **Do not** forget put prefix **"+"** and **"-"** during **modification** lines. 
-## Code BLOCK Format: - - Every code modification **must follow** this pattern strictly, ensuring that added, removed, and unchanged lines are marked correctly with **"+"**, **"-"**, or skipped for unchanged lines and keep in mind **give me all of unchanged line and dont summarize them.** - - **Example format**: +## **diff BLOCK** Code Format: File: relativePath/fileName.ext -```diff - package main +```if change is a modification use **diff**, otherwise if it's a **new file** use **language of code** here. + package main - import "fmt" + import "fmt" + import "time" - func main() { + func main() { - fmt.Println("Hello, World!") + fmt.Println("Welcome to Go programming!") + fmt.Println("Current time:", time.Now()) - fmt.Println("This is another unchanged line") - } + fmt.Println("This is another unchanged line") + } ``` - - The **added lines** would have the **"+"** prefix. - - The **removed lines** would have the **"-"** prefix. - - The **unchanged lines** should **remain as is** without any prefix. - -## For Added Lines: - - Any new lines or methods that are added **must be prefixed with a plus sign (+)**. - - If you add methods or classes, mark them as added. - -## For Removed Lines: - - Any lines that are deleted **must be prefixed with a minus sign (-)**. - -## For Unchanged Lines: - - Always give me unchanged line **without summarize** them **exactly like as they were before**. - - Any line of code that remains unchanged **should not be prefixed and should remain as is**. ## Important: - - Under no circumstances, if the Code BLOCK is empty or Code BLOCK is incomplete, do **not** include placeholder comments like "// REST OF THE CODE" or "// IMPLEMENTATION OF....". \ No newline at end of file + - Under no circumstances, if the some part of **body** or **block** is **empty** or **incomplete**, do **not** include placeholder comments like "// REST OF THE CODE" or "// IMPLEMENTATION OF....". 
\ No newline at end of file diff --git a/embedding_store/contracts/embedding_store.go b/embedding_store/contracts/embedding_store.go index ebe89f8..68de3b9 100644 --- a/embedding_store/contracts/embedding_store.go +++ b/embedding_store/contracts/embedding_store.go @@ -3,6 +3,6 @@ package contracts // IEmbeddingStore defines the interface for managing code and embeddings. type IEmbeddingStore interface { Save(key string, code string, embeddings []float64) - FindRelevantChunks(queryEmbedding []float64, topN int, threshold float64) []string + FindRelevantChunks(queryEmbedding []float64, topN int, embeddingModel string, threshold float64) []string CosineSimilarity(vec1, vec2 []float64) float64 } diff --git a/embedding_store/embedding_store.go b/embedding_store/embedding_store.go index 95f16fd..bc293ee 100644 --- a/embedding_store/embedding_store.go +++ b/embedding_store/embedding_store.go @@ -15,6 +15,25 @@ type EmbeddingStore struct { mu sync.RWMutex } +func (store *EmbeddingStore) FindThresholdByModel(modelName string) float64 { + switch modelName { + case "all-minilm:l6-v2": + return 0.22 + case "mxbai-embed-large": + return 0.4 + case "nomic-embed-text": + return 0.4 + case "text-embedding-3-large": + return 0.4 + case "text-embedding-3-small": + return 0.4 + case "text-embedding-ada-002": + return 0.75 + default: + return 0.3 + } +} + // NewEmbeddingStoreModel initializes a new CodeEmbeddingStoreModel. func NewEmbeddingStoreModel() contracts.IEmbeddingStore { return &EmbeddingStore{ @@ -62,7 +81,7 @@ func (store *EmbeddingStore) CosineSimilarity(vec1, vec2 []float64) float64 { } // FindRelevantChunks retrieves the relevant code chunks from the embedding store based on a similarity threshold. 
-func (store *EmbeddingStore) FindRelevantChunks(queryEmbedding []float64, topN int, threshold float64) []string { +func (store *EmbeddingStore) FindRelevantChunks(queryEmbedding []float64, topN int, embeddingModel string, threshold float64) []string { type similarityResult struct { FileName string Similarity float64 @@ -70,6 +89,10 @@ func (store *EmbeddingStore) FindRelevantChunks(queryEmbedding []float64, topN i var results []similarityResult + if threshold == 0 { + threshold = store.FindThresholdByModel(embeddingModel) + } + // Calculate similarity for each stored embedding for fileName, storedEmbedding := range store.EmbeddingsStore { similarity := store.CosineSimilarity(queryEmbedding, storedEmbedding) diff --git a/go.mod b/go.mod index 7891a09..a02886c 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,6 @@ go 1.23.2 require ( github.com/alecthomas/chroma/v2 v2.14.0 - github.com/cenkalti/backoff/v4 v4.3.0 github.com/charmbracelet/lipgloss v1.0.0 github.com/pterm/pterm v0.12.79 github.com/smacker/go-tree-sitter v0.0.0-20240827094217-dd81d9e9be82 diff --git a/go.sum b/go.sum index 3328d82..49af607 100644 --- a/go.sum +++ b/go.sum @@ -24,8 +24,6 @@ github.com/alecthomas/repr v0.4.0/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW5 github.com/atomicgo/cursor v0.0.1/go.mod h1:cBON2QmmrysudxNBFthvMtN32r3jxVRIvzkUiF/RuIk= github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= -github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= -github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/charmbracelet/lipgloss v1.0.0 h1:O7VkGDvqEdGi93X+DeqsQ7PKHDgtQfF8j8/O2qFMQNg= github.com/charmbracelet/lipgloss v1.0.0/go.mod h1:U5fy9Z+C38obMs+T+tJqst9VGzlOYGj4ri9reL3qUlo= github.com/charmbracelet/x/ansi v0.4.2 h1:0JM6Aj/g/KC154/gOP4vfxun0ff6itogDYk41kof+qk= diff --git 
a/providers/token_management.go b/providers/token_management.go index 0c57a48..c72e391 100644 --- a/providers/token_management.go +++ b/providers/token_management.go @@ -69,10 +69,10 @@ func (tm *tokenManager) DisplayTokens(providerName string, model string, embeddi cost := tm.CalculateCost(providerName, model, tm.usedInputToken, tm.usedOutputToken) costEmbedding := tm.CalculateCost(providerName, embeddingModel, tm.usedEmbeddingInputToken, tm.usedEmbeddingOutputToken) - tokenInfo := fmt.Sprintf("Token Used: '%s' - Cost: '%s'$ - Chat Model: '%s'", fmt.Sprint(tm.usedToken), fmt.Sprintf("%.6f", cost), model) + tokenInfo := fmt.Sprintf("Token Used: %s - Cost: %s $ - Chat Model: %s", fmt.Sprint(tm.usedToken), fmt.Sprintf("%.6f", cost), model) if isRag { - embeddingTokenDetails := fmt.Sprintf("Token Used: '%s' - Cost: '%s'$ - Embedding Model: '%s'", fmt.Sprint(tm.usedEmbeddingToken), fmt.Sprintf("%.6f", costEmbedding), embeddingModel) + embeddingTokenDetails := fmt.Sprintf("Token Used: %s - Cost: %s $ - Embedding Model: %s", fmt.Sprint(tm.usedEmbeddingToken), fmt.Sprintf("%.6f", costEmbedding), embeddingModel) tokenInfo = tokenInfo + "\n" + embeddingTokenDetails } diff --git a/utils/confirm_prompt.go b/utils/confirm_prompt.go index c021082..27a0d33 100644 --- a/utils/confirm_prompt.go +++ b/utils/confirm_prompt.go @@ -13,7 +13,7 @@ func ConfirmPrompt(path string) (bool, error) { reader := bufio.NewReader(os.Stdin) // Styled prompt message - fmt.Printf(lipgloss.BlueSky.Render(fmt.Sprintf("Do you want to accept the change for file '%v'%s", lipgloss.LightBlueB.Render(path), lipgloss.BlueSky.Render(" ? (y/n): ")))) + fmt.Printf(lipgloss.BlueSky.Render(fmt.Sprintf("Do you want to accept the change for file %v%s", lipgloss.LightBlueB.Render(path), lipgloss.BlueSky.Render(" ? 
(y/n): ")))) for { // Read user input diff --git a/utils/graceful_shutdown.go b/utils/graceful_shutdown.go index 96a712f..ca5d1c5 100644 --- a/utils/graceful_shutdown.go +++ b/utils/graceful_shutdown.go @@ -7,28 +7,22 @@ import ( "os" ) -func GracefulShutdown(ctx context.Context, done chan bool, sigs chan os.Signal, chatHistoryCleanUp func()) { - - go func() { - for { - select { - case <-sigs: - chatHistoryCleanUp() - done <- true // Signal the application to exit - } - } - }() - - // Defer the recovery function to handle panics +func GracefulShutdown(ctx context.Context, stop context.CancelFunc, cleanup func()) { + // Defer the recovery function to handle any panics during cleanup defer func() { if r := recover(); r != nil { - fmt.Println(lipgloss.Red.Render(fmt.Sprintf("recovered from panic: %v", r))) - chatHistoryCleanUp() - done <- true // Signal the application to exit + fmt.Println(lipgloss.Red.Render(fmt.Sprintf("Recovered from panic: %v", r))) + cleanup() } }() + // Wait for the context to be canceled by an external signal (e.g., SIGINT or SIGTERM) <-ctx.Done() - close(done) - return + + stop() // Cancel the context to stop further processing + + // When the context is canceled, perform cleanup + cleanup() + + os.Exit(0) // Exit program gracefully } diff --git a/utils/ignore_files.go b/utils/ignore_files.go index 2e70ed4..80ae3c1 100644 --- a/utils/ignore_files.go +++ b/utils/ignore_files.go @@ -58,6 +58,20 @@ func IsDefaultIgnored(path string) bool { "*.dll", "*.log", "*.bak", + ".mp3", + ".wav", + ".aac", + ".flac", + ".ogg", + ".jpg", + ".jpeg", + ".png", + ".gif", + ".mkv", + ".mp4", + ".avi", + ".mov", + ".wmv", } // Split the path into parts based on the file separator @@ -65,6 +79,7 @@ func IsDefaultIgnored(path string) bool { // Check each part for any ignore patterns for _, part := range parts { + part = strings.ToLower(part) for _, pattern := range ignorePatterns { if strings.HasPrefix(pattern, "*") { // If the pattern starts with '*', check for 
suffix diff --git a/utils/retry_backoff.go b/utils/retry_backoff.go deleted file mode 100644 index ee46c41..0000000 --- a/utils/retry_backoff.go +++ /dev/null @@ -1,21 +0,0 @@ -package utils - -import ( - "github.com/cenkalti/backoff/v4" - "time" -) - -// RetryWithBackoff Function to encapsulate the retry logic with backoff -func RetryWithBackoff(operation func() error, maxRetries uint64) error { - // Create a new exponential backoff configuration - expBackoff := backoff.NewExponentialBackOff() - - // Set a max interval between retries (optional) - expBackoff.MaxInterval = 5 * time.Second - - // Wrap the backoff with a fixed number of retries - backoffWithRetries := backoff.WithMaxRetries(expBackoff, maxRetries) - - // Retry the operation using the backoff strategy with the retry limit - return backoff.Retry(operation, backoffWithRetries) -}