-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfrankenstein.py
125 lines (87 loc) · 4.55 KB
/
frankenstein.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import discord
from transformers import AutoTokenizer, AutoModelForCausalLM
from pre_process_text import process_msg_text
from utils import load_secrets, BOS, SEP, EOS
class MimicBot(discord.Client):
    """
    Discord bot that replies to prompted messages with text generated by a
    per-user fine-tuned GPT-2 model.

    Created after loosely following Thomas Chaigneau's article:
    [Building and Launching Your Discord Bot: A Step-by-Step Guide](https://medium.com/@thomaschaigneau.ai/building-and-launching-your-discord-bot-a-step-by-step-guide-f803f7943d33).
    Mixed with this Minimal Bot in [the discordpy docs](https://discordpy.readthedocs.io/en/stable/quickstart.html#a-minimal-bot).
    """

    def __init__(self, intents: "discord.Intents", username: str, model_prompt: str) -> None:
        """
        Constructor.

        Input:
        - intents: discord intents object with message content set to True.
        - username: Username of model to load.
        - model_prompt: Prefix a message must start with for the bot to respond.
        """
        super().__init__(intents=intents)
        self.username = username
        self.model_prompt = model_prompt
        # Tokenizer and model checkpoints live side by side under
        # models/gpt2/<username>/ on the local filesystem.
        self.tokenizer = AutoTokenizer.from_pretrained(f"models/gpt2/{username}/tokenizer")
        self.model = AutoModelForCausalLM.from_pretrained(f"models/gpt2/{username}/model")
        return

    def __get_model_output(self, message_text: str) -> str:
        """
        Function to get model output from an input string.

        Input:
        - message_text : The text of the message received.
        Output:
        - model_response : The output of the model.
        """
        # Use pre-processing function to ensure same preprocessing as the model was trained on.
        msg_formatted = {"content": message_text, "mentions": []}
        preprocessed_message = process_msg_text(msg_formatted)
        # NOTE: Empirically better performance with a terminating period, so add
        # one if it is missing (guarding against an empty string, which would
        # previously have raised IndexError on [-1]).
        # BUG FIX: the original version wrote this as an unparenthesized ternary,
        # which Python parsed as one big conditional expression — messages that
        # already ended in '.' lost the entire prompt, and messages that did not
        # never received the SEP token. Only the period is conditional now.
        if preprocessed_message and not preprocessed_message.endswith('.'):
            preprocessed_message += '.'
        # Format input as the text before the separator token so the model only
        # needs to predict the response.
        preprocessed_message = f"{BOS} " + preprocessed_message + f" {SEP} "
        # Tokenize preprocessed message and get response. Model params are hard-coded for now.
        inputs = self.tokenizer(preprocessed_message, return_tensors="pt").input_ids
        outputs = self.model.generate(inputs,
                                      max_new_tokens=200,
                                      do_sample=True,
                                      top_p=0.97,
                                      top_k=150,
                                      temperature=1.0
                                      )
        decoded_output = self.tokenizer.batch_decode(outputs, skip_special_tokens=False)[0]
        # We only want the model's response (after SEP token). Including prompt is redundant.
        first_sep_id = decoded_output.find(SEP)
        # Start right after SEP when present; fall back to the string start if the
        # model somehow dropped it (find() == -1 would otherwise corrupt the slice).
        start_id = first_sep_id + len(SEP) if first_sep_id != -1 else 0
        end_id = decoded_output.find(EOS)
        if end_id == -1:
            # No EOS generated (e.g. hit max_new_tokens): keep everything to the
            # end rather than silently dropping the final character via a -1 slice.
            end_id = len(decoded_output)
        # "<unk>" shows up where apostrophes were in training data — map it back.
        model_response = decoded_output[start_id:end_id].replace("<unk>", "'")
        return model_response

    async def on_ready(self) -> None:
        """
        Ran when bot is alive and connected.

        Input:
        - None
        Output:
        - None (prints to console)
        """
        print(f'Logged on as {self.user}!')
        return

    async def on_message(self, message: "discord.Message") -> None:
        """
        Event Handler for when a message is sent.
        Handles sending of response message as well.

        Input:
        - message: Message object received from discord API through discord library.
        Output:
        - None (sends message through discord library functionality)
        """
        if message.content.startswith(self.model_prompt):
            # Strip the model prompt prefix; only the remainder is fed to the model.
            message_text = message.content[len(self.model_prompt):]
            model_response = self.__get_model_output(message_text)
            await message.channel.send(f"@{message.author} {model_response}")
def main():
    """Entry point: load secrets, configure intents, and start the bot."""
    secrets = load_secrets()
    # Username of the model to load — fill this in before running.
    username = ...
    # Prompt token: messages addressed to the bot must begin with this sequence.
    model_prompt = f"${username}"
    bot_intents = discord.Intents.default()
    bot_intents.message_content = True
    bot = MimicBot(intents=bot_intents, username=username, model_prompt=model_prompt)
    bot.run(secrets["bot_token"])


if __name__ == "__main__":
    main()