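"""Tri-party LLM debate experiment.

Three agents debate a three-way topic, one stance each: Player 1 is the weak
model (gpt-4o-mini) and Players 2 and 3 are gpt-4o. After a fixed number of
rounds every agent votes for a stance, and the transcript plus vote tally is
appended to a CSV file.

Example invocations (values are illustrative, taken from the built-in suite):

    # Run the built-in suite of eight topics:
    python experiment.py

    # Run a single topic:
    python experiment.py \
        --topic "Moral Frameworks: Utilitarianism vs. Deontology vs. Virtue Ethics" \
        --stances utilitarianism deontology "virtue ethics" --num_rounds 3

Requires OPENAI_API_KEY, e.g. in a local .env file.
"""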
import argparse
import csv
import dataclasses
import functools
import random
import time
from collections import defaultdict

import openai
from dotenv import load_dotenv
from tqdm import tqdm


# Define a retry decorator.
def retry_with_exponential_backoff(
    func,
    initial_delay: float = 1,
    exponential_base: float = 2,
    jitter: bool = True,
    max_retries: int = 10,
    errors: tuple = (openai.RateLimitError,),
):
    """Retry a function with exponential backoff and optional random jitter."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        num_retries = 0
        delay = initial_delay
        # Loop until the call succeeds or max_retries is exceeded; any error
        # outside `errors` propagates immediately.
        while True:
            try:
                return func(*args, **kwargs)
            # Retry only on the specified errors (rate limits by default).
            except errors:
                num_retries += 1
                if num_retries > max_retries:
                    raise Exception(
                        f"Maximum number of retries ({max_retries}) exceeded."
                    )
                # Grow the delay exponentially; jitter desynchronizes
                # concurrent retries.
                delay *= exponential_base * (1 + jitter * random.random())
                time.sleep(delay)

    return wrapper

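# load_dotenv() pulls variables from a local .env file so that openai.OpenAI()
# can pick up OPENAI_API_KEY from the environment.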
load_dotenv()
client = openai.OpenAI()


@dataclasses.dataclass
class Agent:
    system_prompt: str
    model: str
    name: str


@retry_with_exponential_backoff
def generate_message(
    conversation: list[str], agent: Agent, voting_prompt: str | None = None
) -> str:
    # The agent sees the conversation so far as a sequence of user messages,
    # preceded by its own stance-specific system prompt.
    messages = [{"role": "system", "content": agent.system_prompt}] + [
        {"role": "user", "content": message} for message in conversation
    ]
    if voting_prompt:
        messages.append({"role": "user", "content": voting_prompt})
    response = client.chat.completions.create(
        model=agent.model,
        messages=messages,
    )
    # Votes are returned bare; regular debate turns are prefixed with the
    # speaker's name so the transcript identifies who said what.
    if voting_prompt:
        return response.choices[0].message.content
    return f"{agent.name}: {response.choices[0].message.content}"


def get_system_prompt(topic: str, stance: str, name: str) -> str:
    return f"""You are {name}, having a debate/conversation with two other players on a tri-party debate topic: "{topic}".
You are a proponent of {stance}. Your argument should be conversational and in plaintext.
Do not say your own name like "{name}: ".
"""


def get_generic_voting_prompt(stances: list[str]) -> str:
    return (
        f"Please vote for one of the following options: {', '.join(stances)}. "
        "Do not output anything other than the option you voted for. "
        "Your vote is case-sensitive. Do not change the case or formatting "
        "of the option you voted for."
    )
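# For example, get_generic_voting_prompt(["democracy", "autocracy", "technocracy"])
# asks each agent to answer with exactly one of those strings, verbatim, so the
# votes can be tallied by string equality.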


def run_debate(args):
    tri_party_debate_topic = args.topic
    stances = args.stances
    names = ["Player 1", "Player 2", "Player 3"]
    for i, stance in tqdm(enumerate(stances), total=len(stances)):
        print("=" * 100)
        print(f"Running experiment {i + 1} with topic: {tri_party_debate_topic}")
        print("-" * 100)
        # Agent 1 (the weak model) gets to argue each stance at least once.
        agent1 = Agent(
            name=names[0],
            system_prompt=get_system_prompt(tri_party_debate_topic, stance, names[0]),
            model="gpt-4o-mini",
        )
        # The other agents rotate through the remaining stances.
        agent2 = Agent(
            name=names[1],
            system_prompt=get_system_prompt(
                tri_party_debate_topic, stances[(i + 1) % 3], names[1]
            ),
            model="gpt-4o",
        )
        agent3 = Agent(
            name=names[2],
            system_prompt=get_system_prompt(
                tri_party_debate_topic, stances[(i + 2) % 3], names[2]
            ),
            model="gpt-4o",
        )
        agent_list = [agent1, agent2, agent3]
        print(f"Player 1 ({agent1.model}) stance: {stance}")
        print(f"Player 2 ({agent2.model}) stance: {stances[(i + 1) % 3]}")
        print(f"Player 3 ({agent3.model}) stance: {stances[(i + 2) % 3]}")
        print("-" * 100)
        # Run the debate: each round, every agent speaks once in fixed order.
        conversation = []
        for _ in range(args.num_rounds):
            for agent in agent_list:
                conversation.append(message := generate_message(conversation, agent))
                print(message)
                print("-" * 100)
        # Collect one vote per agent.
        votes = defaultdict(int)
        agent_to_vote = {}
        for agent in agent_list:
            vote = generate_message(conversation, agent, args.voting_prompt)
            votes[vote] += 1
            agent_to_vote[agent.name] = vote
            print(f"{agent.name} voted for {vote}")
        print(f"Votes: {votes}")
        # Tally the votes; with three voters, the only possible tie is 1-1-1.
        max_votes = max(votes.values())
        winners = [option for option, count in votes.items() if count == max_votes]
        if len(winners) > 1:
            winner = "Tie: " + ", ".join(winners)
            print(f"Tie between: {', '.join(winners)}")
        else:
            winner = winners[0]
            print(f"Winner: {winner}")
        print("=" * 100)
        # Append the transcript and result to the CSV file.
        with open(args.transcript_file, "a", newline="") as f:
            writer = csv.writer(f)
            # Write headers if the file is empty.
            if f.tell() == 0:
                writer.writerow(
                    [
                        "Debate Number",
                        "Topic",
                        "Player 1 Stance",
                        "Player 2 Stance",
                        "Player 3 Stance",
                        "Conversation",
                        "Votes",
                        "Agent Votes",
                        "Winner",
                    ]
                )
            writer.writerow(
                [
                    i + 1,
                    tri_party_debate_topic,
                    stance,
                    stances[(i + 1) % 3],
                    stances[(i + 2) % 3],
                    "\n".join(conversation),
                    str(dict(votes)),
                    str(agent_to_vote),
                    winner,
                ]
            )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Arguments for a single experiment; leave --topic unset to run the
    # full built-in suite below.
    parser.add_argument(
        "--topic",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--stances",
        type=str,
        nargs="+",
        default=None,
    )
    parser.add_argument("--num_rounds", type=int, default=5)
    parser.add_argument("--transcript_file", type=str, default="debate.csv")
    parser.add_argument("--eval_data_folder", type=str, default="results")
    args = parser.parse_args()
    if args.topic is None:
        # Multiple experiments: run every built-in topic with its stances.
        topics = [
            "Forms of Government: Democracy vs. Autocracy vs. Technocracy",
            "Electoral Systems: First Past the Post vs. Proportional Representation vs. Ranked Choice Voting",
            "Climate Policy: Carbon Tax vs. Cap-and-Trade vs. Direct Regulation",
            "Moral Frameworks: Utilitarianism vs. Deontology vs. Virtue Ethics",
            "Cultural Representation: Nationalism vs. Multiculturalism vs. Cosmopolitanism",
            "Immigration Policies: Open Borders vs. Controlled Immigration vs. Merit-Based Systems",
            "Social Media Regulation: Self-Regulation vs. Government Oversight vs. Community Moderation",
            "Income Distribution: Universal Basic Income vs. Progressive Taxation vs. Flat Tax",
        ]
        all_stances = [
            ["democracy", "autocracy", "technocracy"],
            [
                "first past the post",
                "proportional representation",
                "ranked choice voting",
            ],
            ["carbon tax", "cap-and-trade", "direct regulation"],
            ["utilitarianism", "deontology", "virtue ethics"],
            ["nationalism", "multiculturalism", "cosmopolitanism"],
            ["open borders", "controlled immigration", "merit-based systems"],
            ["self-regulation", "government oversight", "community moderation"],
            ["universal basic income", "progressive taxation", "flat tax"],
        ]
        for topic, stances in zip(topics, all_stances):
            args.topic = topic
            args.stances = stances
            args.voting_prompt = get_generic_voting_prompt(stances)
            run_debate(args)
    else:
        # Single experiment: --stances must accompany --topic, and the voting
        # prompt is built from the provided stances.
        if args.stances is None:
            parser.error("--stances is required when --topic is given")
        args.voting_prompt = get_generic_voting_prompt(args.stances)
        run_debate(args)