File size: 5,274 Bytes
a28141c
4fa9a8c
 
ef61b5f
95e82c0
0906c57
 
 
 
a28141c
0906c57
 
 
 
eebd0c6
 
0906c57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4fa9a8c
6535a3b
a28141c
30ae265
9741ad8
c5230d3
4fa9a8c
 
cfe32c7
 
4fa9a8c
 
 
a28141c
cfe32c7
 
0906c57
 
 
 
 
 
 
 
 
 
 
55f45b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4fa9a8c
 
a28141c
4fa9a8c
 
 
 
 
 
 
 
 
 
 
 
0906c57
4fa9a8c
a28141c
 
 
 
55f45b4
 
 
 
 
 
 
 
 
 
 
a28141c
3b8fa51
6c2e718
 
a28141c
a4bba8d
6c2e718
0906c57
 
 
6c2e718
a28141c
 
55f45b4
a28141c
0906c57
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import gradio as gr
import transformers
import torch
from peft import PeftModel
import os
import csv
import huggingface_hub
from huggingface_hub import Repository, hf_hub_download, upload_file
from datetime import datetime

# Hugging Face dataset repository that receives the logged chat exchanges.
DATASET_REPO_URL = "https://huggingface.co/datasets/JerniganLab/chat-data"
DATASET_REPO_ID = "JerniganLab/chat-data"
# CSV log file, located inside the local "data" directory.
DATA_FILENAME = "data.csv"
DATA_FILE = os.path.join("data", DATA_FILENAME)
# Access token taken from the environment; None when HF_TOKEN is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")

# overriding/appending to the gradio template
# Snippet added to Gradio's frontend page: on load it clicks the submit button
# once, using `window.hasBeenRun` as the run-once guard.
SCRIPT = """
<script>
if (!window.hasBeenRun) {
    window.hasBeenRun = true;
    console.log("should only happen once");
    document.querySelector("button.submit").click();
}
</script>
"""
_INDEX_HTML = os.path.join(gr.routes.STATIC_TEMPLATE_LIB, "frontend", "index.html")
# "a+" keeps the original create-if-missing/append semantics while also
# allowing a read, so the snippet is appended only when it is not already in
# the template (otherwise every process start would add another copy).
# errors="replace" ensures the membership check never fails on decoding.
with open(_INDEX_HTML, "a+", encoding="utf-8", errors="replace") as f:
    f.seek(0)
    if SCRIPT not in f.read():
        f.write(SCRIPT)

# Best-effort fetch of the CSV log from the dataset repository: a failure is
# reported on stdout and start-up continues.
# NOTE(review): DATA_DIRNAME is not defined in the visible part of this file,
# so this call currently raises NameError, which the handler below now reports
# instead of hiding. Before defining it, confirm the intended directory —
# presumably it must not be the "data" directory that the dataset repository
# is cloned into right after this; verify.
try:
    hf_hub_download(
        repo_id=DATASET_REPO_ID,
        filename=DATA_FILENAME,
        cache_dir=DATA_DIRNAME,
        repo_type='dataset',
        force_filename=DATA_FILENAME
    )
except Exception as err:
    # Catch Exception (not a bare except) so KeyboardInterrupt/SystemExit still
    # propagate, and include the real cause in the message.
    print(f"file not found ({type(err).__name__}: {err})")

# Local checkout ("data") of the dataset repository, authenticated with
# HF_TOKEN; `repo` is used later to push newly logged rows.
repo = Repository(
    local_dir="data",
    clone_from=DATASET_REPO_URL,
    use_auth_token=HF_TOKEN,
)




# PEFT adapter repository and the base checkpoint it is applied to.
model_id = "JerniganLab/interviews-and-qa"
base_model = "meta-llama/Meta-Llama-3-8B-Instruct"

# Request bfloat16 at load time. Passing the dtype through the pipeline's
# `model_kwargs` has no effect when an already-instantiated model is supplied,
# so the weights would otherwise be loaded in the default precision.
llama_model = transformers.AutoModelForCausalLM.from_pretrained(
    base_model,
    torch_dtype=torch.bfloat16,
)

# Text-generation pipeline on the GPU; the tokenizer is loaded from the base
# checkpoint name.
pipeline = transformers.pipeline(
    "text-generation",
    model=llama_model,
    tokenizer=base_model,
    device="cuda",
)

# Replace the pipeline's model with the adapter-wrapped version of the same
# base model so generation uses the fine-tuned weights.
pipeline.model = PeftModel.from_pretrained(llama_model, model_id)

def store_message(message: str, system_prompt: str, response: str) -> None:
    """Append one chat exchange to the CSV log and push it to the dataset repo.

    Nothing is written or pushed when either ``message`` or ``response`` is
    empty.

    Args:
        message: The user's message.
        system_prompt: The system prompt used for the exchange.
        response: The generated reply.
    """
    if not (response and message):
        return

    # newline="" is what the csv module expects of files it writes to (it
    # emits its own line terminators); utf-8 keeps non-ASCII chat text
    # independent of the platform's default encoding.
    with open(DATA_FILE, "a", newline="", encoding="utf-8") as csvfile:
        writer = csv.DictWriter(
            csvfile,
            fieldnames=["message", "system_prompt", "response", "time"],
        )
        writer.writerow(
            {
                "message": message,
                "system_prompt": system_prompt,
                "response": response,
                "time": str(datetime.now()),
            }
        )
    # Publish the updated log after the file has been closed.
    repo.push_to_hub()


# def chat_function(message, history, system_prompt, max_new_tokens, temperature):
#     messages = [{"role":"system","content":system_prompt},
#                 {"role":"user", "content":message}]
#     prompt = pipeline.tokenizer.apply_chat_template(
#         messages,
#         tokenize=False,
#         add_generation_prompt=True,)
#     terminators = [
#         pipeline.tokenizer.eos_token_id,
#         pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")]
#     outputs = pipeline(
#         prompt,
#         max_new_tokens = max_new_tokens,
#         eos_token_id = terminators,
#         do_sample = True,
#         temperature = temperature + 0.1,
#         top_p = 0.9,)
#     return outputs[0]["generated_text"][len(prompt):]

def chat_function(message, history, max_new_tokens, temperature):
    """Generate one reply to ``message`` and log the exchange.

    Args:
        message: The latest user message.
        history: Earlier turns passed in by the chat interface. It is not
            used: each reply is generated from the system prompt and
            ``message`` only.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature from the UI; 0.1 is added to it
            before sampling.

    Returns:
        The generated text with the prompt prefix removed.
    """
    system_prompt = "I want you to embody a 30-year-old Southern Black woman graduate student who is kind, empathetic, direct, unapologetically Black, and who communicates predominantly in African American Vernacular English. I want you to act as a companion for graduate students who are enrolled in primarily white universities. As their companion, I want you to employ principles of cognitive behavioral therapy, the rhetoric of Black American digital spaces, and Black American humor in your responses to the challenges that students encounter with peers, faculty, or staff. I want you to engage in role-play with them, providing them a safe place to develop potential responses to microaggressions. I want you to help them seek resolutions for their problems."
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": message},
    ]
    prompt = pipeline.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # Stop on either the tokenizer's EOS token or the "<|eot_id|>" token.
    terminators = [
        pipeline.tokenizer.eos_token_id,
        pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ]
    outputs = pipeline(
        prompt,
        # UI sliders may deliver a float; generation needs a whole number.
        max_new_tokens=int(max_new_tokens),
        eos_token_id=terminators,
        do_sample=True,
        # The offset keeps the value positive when the UI sends 0.
        temperature=temperature + 0.1,
        top_p=0.9,
    )
    # The pipeline output starts with the prompt text; keep only what follows.
    response = outputs[0]["generated_text"][len(prompt):]
    # Previously this call referenced an undefined name and raised NameError
    # after every generation.
    store_message(message, system_prompt, response)
    return response

"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
# demo = gr.ChatInterface(
#     chat_function,
#     textbox=gr.Textbox(placeholder="Enter message here", container=False, scale = 7),
#     chatbot=gr.Chatbot(height=400),
#     additional_inputs=[
#         gr.Textbox("You are helpful AI", label="System Prompt"),
#         gr.Slider(100,4000, label="Max New Tokens"),
#         gr.Slider(0,1, label="Temperature")
#     ]
#     )

# Chat UI wired to chat_function. The two sliders are passed to it, in order,
# as max_new_tokens and temperature.
demo = gr.ChatInterface(
    chat_function,
    textbox=gr.Textbox(placeholder="Enter message here", container=False, scale=7),
    # The Chatbot's type is set explicitly so it matches the interface's
    # type="messages" below instead of relying on the component default.
    chatbot=gr.Chatbot(height=400, type="messages"),
    additional_inputs=[
        gr.Slider(100, 4000, label="Max New Tokens"),
        gr.Slider(0, 1, label="Temperature"),
    ],
    type="messages",
    save_history=True,
)


# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()