-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path02_baseline.py
32 lines (26 loc) · 1.23 KB
/
02_baseline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import asyncio
import pandas as pd
from prompt_evaluator import PromptEvaluator
from vertexai.generative_models import HarmBlockThreshold, HarmCategory
def _read_prompt() -> str:
    """Read the prompt to evaluate from stdin.

    Returns:
        The prompt exactly as typed (it is not stripped), so the evaluator
        receives the user's original text.

    Raises:
        SystemExit: If the entered prompt is empty or whitespace-only;
            evaluating it would only spend model calls on nothing.
    """
    prompt = input("Please enter the prompt for evaluation: ")
    if not prompt.strip():
        raise SystemExit("No prompt entered; nothing to evaluate.")
    return prompt


def main() -> None:
    """Build the baseline PromptEvaluator and evaluate one user-supplied prompt.

    Loads the evaluation data from ``test.csv``, configures the target and
    review Gemini models plus safety settings, reads a prompt from stdin and
    runs the evaluator's async ``main`` coroutine on it.
    """
    # NOTE(review): this variable is named df_train (and was previously
    # commented as "training data"), but the file read is test.csv --
    # confirm which data split this baseline is meant to use.
    df_train = pd.read_csv('test.csv')

    # "Target" model -- presumably the model the prompt under evaluation is
    # run against; confirm in PromptEvaluator. Temperature 0 minimises
    # sampling variability between runs.
    target_model_name = "gemini-1.5-flash"
    target_model_config = {
        "temperature": 0, "max_output_tokens": 1000
    }

    # "Review" model. The very small output budget suggests it is expected
    # to return only a short verdict/score -- TODO confirm against
    # review_prompt_template.txt.
    review_model_name = "gemini-1.5-flash"
    review_model_config = {
        "temperature": 0, "max_output_tokens": 10
    }

    # BLOCK_NONE turns off safety-filter blocking for these four harm
    # categories so evaluation responses are not withheld by the filter.
    safety_settings = {
        HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
        HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
    }

    # Path to the text file holding the review prompt template.
    review_prompt_template_path = 'review_prompt_template.txt'

    # Arguments are positional; their order must match PromptEvaluator's
    # constructor, which is defined in prompt_evaluator (not visible here).
    evaluator = PromptEvaluator(
        df_train,
        target_model_name,
        target_model_config,
        review_model_name,
        review_model_config,
        safety_settings,
        review_prompt_template_path,
    )

    prompt = _read_prompt()
    # evaluator.main(prompt) is a coroutine; asyncio.run drives it to
    # completion on a fresh event loop.
    asyncio.run(evaluator.main(prompt))


if __name__ == "__main__":
    main()