Skip to content

Commit

Permalink
Merge pull request #43 from meysamhadeli/test/add-tests-for-analyzer
Browse files Browse the repository at this point in the history
test: add tests for analyzer
  • Loading branch information
meysamhadeli authored Nov 6, 2024
2 parents 676823a + 1ed5545 commit b409ddf
Show file tree
Hide file tree
Showing 6 changed files with 282 additions and 45 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,6 @@ jobs:
working-directory: ./
run: go build -v ./...

# - name: test
# working-directory: ./
# run: go test -v ./...
- name: test
working-directory: ./
run: go test -v ./...
32 changes: 18 additions & 14 deletions code_analyzer/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,10 @@ func (analyzer *CodeAnalyzer) GetProjectFiles(rootDir string) ([]models.FileData
}

relativePath, err := filepath.Rel(rootDir, path)
relativePath = strings.ReplaceAll(relativePath, "\\", "/")

// Check if the current directory or file should be skipped based on default ignore patterns
if utils.IsDefaultIgnored(path) {
if utils.IsDefaultIgnored(relativePath) {
// Skip the directory or file
if d.IsDir() {
// If it's a directory, skip the whole directory
Expand All @@ -106,8 +107,8 @@ func (analyzer *CodeAnalyzer) GetProjectFiles(rootDir string) ([]models.FileData
return nil // Skip this file
}

// Read the file content
content, err := ioutil.ReadFile(relativePath)
// Read the file content using the full path
content, err := ioutil.ReadFile(path) // Use full path from WalkDir
if err != nil {
return fmt.Errorf("failed to read file: %s, error: %w", relativePath, err)
}
Expand Down Expand Up @@ -217,7 +218,7 @@ func (analyzer *CodeAnalyzer) ProcessFile(filePath string, sourceCode []byte) []
// ExtractCodeChanges extracts code changes from the given text.
func (analyzer *CodeAnalyzer) ExtractCodeChanges(text string) ([]models.CodeChange, error) {
if text == "" {
return nil, fmt.Errorf("input text is empty")
return nil, nil
}

// Regex patterns
Expand Down Expand Up @@ -254,24 +255,23 @@ func (analyzer *CodeAnalyzer) ExtractCodeChanges(text string) ([]models.CodeChan
func (analyzer *CodeAnalyzer) TryGetInCompletedCodeBlocK(relativePaths string) (string, error) {
var codes []string

re := regexp.MustCompile(`"files"\s*:\s*\[.*?\]`)
// Simplified regex to capture only the array of files
re := regexp.MustCompile(`\[.*?\]`)
match := re.FindString(relativePaths)

// Wrap with braces to create a valid JSON string
jsonContent := "{" + match + "}"

// Deserialize JSON into a generic map
var result struct {
Files []string `json:"files"`
if match == "" {
return "", fmt.Errorf("no file paths found in input")
}

err := json.Unmarshal([]byte(jsonContent), &result)
// Parse the match into a slice of strings
var filePaths []string
err := json.Unmarshal([]byte(match), &filePaths)
if err != nil {
return "", nil
return "", fmt.Errorf("failed to unmarshal JSON: %v", err)
}

// Loop through each relative path and read the file content
for _, relativePath := range result.Files {
for _, relativePath := range filePaths {
content, err := os.ReadFile(relativePath)
if err != nil {
continue
Expand All @@ -280,6 +280,10 @@ func (analyzer *CodeAnalyzer) TryGetInCompletedCodeBlocK(relativePaths string) (
codes = append(codes, fmt.Sprintf("File: %s\n\n%s", relativePath, content))
}

if len(codes) == 0 {
return "", fmt.Errorf("no valid files read")
}

requestedContext := strings.Join(codes, "\n---------\n\n")

return requestedContext, nil
Expand Down
239 changes: 239 additions & 0 deletions code_analyzer/analyzer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
package code_analyzer

import (
"fmt"
"github.com/meysamhadeli/codai/code_analyzer/contracts"
"os"
"path/filepath"
"strings"
"testing"

"github.com/stretchr/testify/assert"
)

// Global variables to store the relative test directory and analyzer
var (
relativePathTestDir string
analyzer contracts.ICodeAnalyzer
)

// setup initializes the relative test directory for all tests
func setup(t *testing.T) {
rootDir, err := os.Getwd()
assert.NoError(t, err)

testDir := t.TempDir() // Create a temporary directory
relativePathTestDir, err = filepath.Rel(rootDir, testDir)

if filepath.IsAbs(relativePathTestDir) {
t.Fatalf("relativeTestDir should be relative, but got an absolute path: %s", relativePathTestDir)
}

analyzer = NewCodeAnalyzer(relativePathTestDir)

// Register cleanup to remove everything inside relativePathTestDir
t.Cleanup(func() {
err := os.RemoveAll(relativePathTestDir)
assert.NoError(t, err, "failed to remove test directory")
})
}

// TestMain runs tests sequentially in the specified order
func TestMain(m *testing.M) {
// Setup before running tests
code := m.Run()
// Teardown after running tests (if needed)
os.Exit(code)
}

func TestRunInSequence(t *testing.T) {
setup(t) // setup before the first test runs

// Call tests in order
t.Run("TestGeneratePrompt", TestGeneratePrompt)
t.Run("TestGeneratePrompt_ActualImplementation", TestGeneratePrompt_ActualImplementation)
t.Run("TestNewCodeAnalyzer", TestNewCodeAnalyzer)
t.Run("TestApplyChanges", TestApplyChanges)
t.Run("TestGetProjectFiles", TestGetProjectFiles)
t.Run("TestProcessFile", TestProcessFile)
t.Run("TestExtractCodeChanges", TestExtractCodeChanges)
t.Run("TestTryGetInCompletedCodeBlock", TestTryGetInCompletedCodeBlock)
t.Run("TestTryGetInCompletedCodeBlockWithAdditionalCharacters", TestTryGetInCompletedCodeBlockWithAdditionalsCharacters)
}

func TestGeneratePrompt(t *testing.T) {
// Call the setup function to initialize the test environment
setup(t)

codes := []string{"code1", "code2"}
history := []string{"prev1", "prev2"}
requestedContext := "Requested context"
userInput := "User request"

finalPrompt, userInputPrompt := analyzer.GeneratePrompt(codes, history, userInput, requestedContext)

// Assert that the outputs contain the expected mocked strings
assert.Contains(t, finalPrompt, "code1")
assert.Contains(t, finalPrompt, "code2")
assert.Contains(t, finalPrompt, "prev1")
assert.Contains(t, finalPrompt, "prev2")
assert.Contains(t, finalPrompt, "Requested context")
assert.Contains(t, userInputPrompt, "User request")
}

func TestGeneratePrompt_ActualImplementation(t *testing.T) {
setup(t)

// Assuming boxStyle.Render and embed_data.CodeBlockTemplate are set up correctly
codes := []string{"code1", "code2"}
history := []string{"prev1", "prev2"}
userInput := "User request"
requestedContext := "Requested context"

finalPrompt, userInputPrompt := analyzer.GeneratePrompt(codes, history, userInput, requestedContext)

// Check the content of the actual prompts here
// This will depend on how you set up boxStyle and embed_data
assert.NotEmpty(t, finalPrompt)
assert.NotEmpty(t, userInputPrompt)
}

// Test for NewCodeAnalyzer
func TestNewCodeAnalyzer(t *testing.T) {
setup(t)

assert.NotNil(t, analyzer)
}

// Test for ApplyChanges
func TestApplyChanges(t *testing.T) {
setup(t)
testFilePath := filepath.Join(relativePathTestDir, "test.txt")

// Create a temporary file for testing
_ = os.WriteFile(testFilePath, []byte("test content"), 0644)
_ = os.WriteFile(testFilePath+".tmp", []byte("test content"), 0644)

err := analyzer.ApplyChanges(testFilePath)
assert.NoError(t, err)

content, err := os.ReadFile(testFilePath)
assert.NoError(t, err)
assert.Equal(t, "test content", string(content))
}

// Test for GetProjectFiles
func TestGetProjectFiles(t *testing.T) {
setup(t)

testFilePath := filepath.Join(relativePathTestDir, "test.go")
ignoreFilePath := filepath.Join(relativePathTestDir, ".gitignore")

_ = os.WriteFile(testFilePath, []byte("package main\nfunc main() {}"), 0644)
_ = os.WriteFile(ignoreFilePath, []byte("node_modules\n"), 0644)

files, codes, err := analyzer.GetProjectFiles(relativePathTestDir)

assert.NoError(t, err)
assert.Len(t, files, 1)
assert.Len(t, codes, 1)

for _, file := range files {
assert.NotEmpty(t, file.RelativePath)
assert.Equal(t, "test.go", filepath.Base(file.RelativePath))
}
}

// Test for ProcessFile
func TestProcessFile(t *testing.T) {
setup(t)
content := []byte("class Test {}")

result := analyzer.ProcessFile("test.cs", content)

assert.Contains(t, result, "test.cs")
assert.NotEmpty(t, result)
}

// Test for ExtractCodeChanges
func TestExtractCodeChanges(t *testing.T) {
setup(t)
text := "File: test.go\n```go\npackage main\n```\nFile: test2.go\n```go\npackage main\n```"

codeChanges, err := analyzer.ExtractCodeChanges(text)

assert.NoError(t, err)
assert.Len(t, codeChanges, 2)
assert.Equal(t, "test.go", codeChanges[0].RelativePath)
assert.Equal(t, "package main", codeChanges[0].Code)
}

func TestExtractCodeChangesWithAdditionalsCharacters(t *testing.T) {
setup(t)
text := "\n\n#####File: test.go#####\n```go\npackage main\n```\nFile: test2.go\n```go\npackage main\n```"

codeChanges, err := analyzer.ExtractCodeChanges(text)

assert.NoError(t, err)
assert.Len(t, codeChanges, 2)
assert.Equal(t, "test.go", codeChanges[0].RelativePath)
assert.Equal(t, "package main", codeChanges[0].Code)
}

func TestExtractCodeChangesWithRemoveCharacters(t *testing.T) {
setup(t)
text := "file:test.go\n```go\npackage main\n```\nFile: test2.go\n```go\npackage main\n```"

codeChanges, err := analyzer.ExtractCodeChanges(text)

assert.NoError(t, err)
assert.Len(t, codeChanges, 2)
assert.Equal(t, "test.go", codeChanges[0].RelativePath)
assert.Equal(t, "package main", codeChanges[0].Code)
}

// Test for TryGetInCompletedCodeBlock
func TestTryGetInCompletedCodeBlock(t *testing.T) {
setup(t) // setup before the first test runs

// Create relative paths for test files within the temporary directory
file1Path := strings.ReplaceAll(filepath.Join(relativePathTestDir, "test.go"), `\`, `\\`)
file2Path := strings.ReplaceAll(filepath.Join(relativePathTestDir, "test2.go"), `\`, `\\`)

_ = os.WriteFile(file1Path, []byte("package main\nfunc main() {}"), 0644)
_ = os.WriteFile(file2Path, []byte("package test\nfunc test() {}"), 0644)

// Prepare JSON-encoded relativePaths string with escaped backslashes
relativePaths := fmt.Sprintf(`["%s", "%s"]`, file1Path, file2Path)

requestedContext, err := analyzer.TryGetInCompletedCodeBlocK(relativePaths)

// Assertions
assert.NoError(t, err)
assert.NotEmpty(t, requestedContext)
assert.Contains(t, requestedContext, "package main\nfunc main() {}")
assert.Contains(t, requestedContext, "package test\nfunc test() {}")
}

// Test for TryGetInCompletedCodeBlock
func TestTryGetInCompletedCodeBlockWithAdditionalsCharacters(t *testing.T) {
setup(t) // setup before the first test runs

// Create relative paths for test files within the temporary directory
file1Path := strings.ReplaceAll(filepath.Join(relativePathTestDir, "test.go"), `\`, `\\`)
file2Path := strings.ReplaceAll(filepath.Join(relativePathTestDir, "test2.go"), `\`, `\\`)

_ = os.WriteFile(file1Path, []byte("package main\nfunc main() {}"), 0644)
_ = os.WriteFile(file2Path, []byte("package test\nfunc test() {}"), 0644)

// Prepare JSON-encoded relativePaths string with escaped backslashes
relativePaths := fmt.Sprintf(`{"###file":["%s", "%s"]\n\n}`, file1Path, file2Path)

requestedContext, err := analyzer.TryGetInCompletedCodeBlocK(relativePaths)

// Assertions
assert.NoError(t, err)
assert.NotEmpty(t, requestedContext)
assert.Contains(t, requestedContext, "package main\nfunc main() {}")
assert.Contains(t, requestedContext, "package test\nfunc test() {}")
}
27 changes: 11 additions & 16 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ var defaultConfig = Config{
ChatCompletionModel: "deepseek-coder-v2",
EmbeddingModel: "nomic-embed-text",
Stream: true,
MaxTokens: 128000,
EncodingFormat: "float",
Temperature: 0.2,
Threshold: 0,
Expand All @@ -56,7 +55,6 @@ func LoadConfigs(rootCmd *cobra.Command, cwd string) *Config {
viper.SetDefault("ai_provider_config.encoding_format", defaultConfig.AIProviderConfig.EncodingFormat)
viper.SetDefault("ai_provider_config.temperature", defaultConfig.AIProviderConfig.Temperature)
viper.SetDefault("ai_provider_config.threshold", defaultConfig.AIProviderConfig.Threshold)
viper.SetDefault("ai_provider_config.max_tokens", defaultConfig.AIProviderConfig.MaxTokens)
viper.SetDefault("ai_provider_config.api_key", defaultConfig.AIProviderConfig.ApiKey)

// Automatically read environment variables
Expand Down Expand Up @@ -102,18 +100,17 @@ func LoadConfigs(rootCmd *cobra.Command, cwd string) *Config {

// bindEnv explicitly binds environment variables to configuration keys
func bindEnv() {
viper.BindEnv("theme")
viper.BindEnv("rag")
viper.BindEnv("ai_provider_config.provider_name")
viper.BindEnv("ai_provider_config.embedding_url")
viper.BindEnv("ai_provider_config.chat_completion_url")
viper.BindEnv("ai_provider_config.chat_completion_model")
viper.BindEnv("ai_provider_config.embedding_model")
viper.BindEnv("ai_provider_config.encoding_format")
viper.BindEnv("ai_provider_config.temperature")
viper.BindEnv("ai_provider_config.threshold")
viper.BindEnv("ai_provider_config.max_tokens")
viper.BindEnv("ai_provider_config.api_key")
_ = viper.BindEnv("theme", "THEME")
_ = viper.BindEnv("rag", "RAG")
_ = viper.BindEnv("ai_provider_config.provider_name", "PROVIDER_NAME")
_ = viper.BindEnv("ai_provider_config.embedding_url", "EMBEDDING_URL")
_ = viper.BindEnv("ai_provider_config.chat_completion_url", "CHAT_COMPLETION_URL")
_ = viper.BindEnv("ai_provider_config.chat_completion_model", "CHAT_COMPLETION_MODEL")
_ = viper.BindEnv("ai_provider_config.embedding_model", "EMBEDDING_MODEL")
_ = viper.BindEnv("ai_provider_config.encoding_format", "ENCODING_FORMAT")
_ = viper.BindEnv("ai_provider_config.temperature", "TEMPERATURE")
_ = viper.BindEnv("ai_provider_config.threshold", "THRESHOLD")
_ = viper.BindEnv("ai_provider_config.api_key", "API_KEY")
}

// bindFlags binds the CLI flags to configuration values.
Expand All @@ -128,7 +125,6 @@ func bindFlags(rootCmd *cobra.Command) {
_ = viper.BindPFlag("ai_provider_config.encoding_format", rootCmd.Flags().Lookup("encoding_format"))
_ = viper.BindPFlag("ai_provider_config.temperature", rootCmd.Flags().Lookup("temperature"))
_ = viper.BindPFlag("ai_provider_config.threshold", rootCmd.Flags().Lookup("threshold"))
_ = viper.BindPFlag("ai_provider_config.max_tokens", rootCmd.Flags().Lookup("max_tokens"))
_ = viper.BindPFlag("ai_provider_config.api_key", rootCmd.Flags().Lookup("api_key"))
}

Expand All @@ -147,6 +143,5 @@ func InitFlags(rootCmd *cobra.Command) {
rootCmd.PersistentFlags().String("encoding_format", defaultConfig.AIProviderConfig.EncodingFormat, "Specifies the format in which the AI embeddings or outputs are encoded (e.g., 'float' for floating-point numbers).")
rootCmd.PersistentFlags().Float32("temperature", defaultConfig.AIProviderConfig.Temperature, "Adjusts the AI model’s creativity by setting a temperature value. Higher values result in more creative or varied responses, while lower values make them more focused.")
rootCmd.PersistentFlags().Float64("threshold", defaultConfig.AIProviderConfig.Threshold, "Sets the threshold for similarity calculations in AI systems (e.g., for retrieving related data in a RAG system). Higher values will require closer matches.")
rootCmd.PersistentFlags().Int("max_tokens", defaultConfig.AIProviderConfig.MaxTokens, "Specifies the Maximum number of token can be used by AI model in request.")
rootCmd.PersistentFlags().String("api_key", defaultConfig.AIProviderConfig.ApiKey, "The API key used to authenticate with the AI service provider.")
}
Loading

0 comments on commit b409ddf

Please sign in to comment.