Spaces:
Runtime error
Update app.py
app.py
CHANGED
@@ -1,64 +1,128 @@
 import gradio as gr
-from huggingface_hub import InferenceClient

...
-def ...
...
-    messages = [{"role": "system", "content": system_message}]
...
-    for message in client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
-    ):
-        token = message.choices[0].delta.content
-
-        response += token
-        yield response
-
-
-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
 demo = gr.ChatInterface(
...
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
-    ],
 )

-
 if __name__ == "__main__":
-    demo.launch()

+import torch
+from transformers import AutoTokenizer, LlamaForCausalLM, BitsAndBytesConfig
+from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList
+from peft import PeftModel
 import gradio as gr

+# Add this new class for custom stopping criteria
+class SentenceEndingCriteria(StoppingCriteria):
+    def __init__(self, tokenizer, end_tokens):
+        self.tokenizer = tokenizer
+        self.end_tokens = end_tokens
+
+    def __call__(self, input_ids, scores, **kwargs):
+        last_token = input_ids[0][-1]
+        return last_token in self.end_tokens

+def load_model():
+    # Modify the model path to use the Hugging Face model ID
+    model_path = "Cioni223/mymodel"  # Replace with your actual model path on HF
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_path,
+        use_fast=False,
+        padding_side="left",
+        model_max_length=4096,
+        token=True  # Add this if your model is private
+    )
+
+    tokenizer.pad_token = tokenizer.eos_token
+
+    # Load merged model with quantization
+    model = LlamaForCausalLM.from_pretrained(
+        model_path,
+        device_map="auto",
+        torch_dtype=torch.float16,
+        quantization_config=BitsAndBytesConfig(load_in_8bit=True)
+    )
+
+    return model, tokenizer

+def format_chat_history(history):
+    formatted_history = ""
+    for user_msg, assistant_msg in history:
+        if user_msg:
+            formatted_history += f"<|start_header_id|>user<|end_header_id|>{user_msg}<|eot_id|>\n"
+        if assistant_msg:
+            formatted_history += f"<|start_header_id|>assistant<|end_header_id|>{assistant_msg}<|eot_id|>\n"
+    return formatted_history

+def chat_response(message, history):
+    # Format the prompt with system message and chat history
+    system_prompt = """<|start_header_id|>system<|end_header_id|>You are Fred, a virtual admissions coordinator for Haven Health Management, a mental health and substance abuse treatment facility. Your role is to respond conversationally and empathetically, like a human agent, using 1-2 sentences per response while guiding the conversation effectively. Your primary goal is to understand the caller's reason for reaching out, gather their medical history, and obtain their insurance details, ensuring the conversation feels natural and supportive. Once all the information is gathered politely end the conversation and if the user is qualified tell the user a live agent will reach out soon. Note: Medicaid is not accepted as insurance.<|eot_id|>"""
+
+    chat_history = format_chat_history(history)
+
+    formatted_prompt = f"""{system_prompt}
+{chat_history}<|start_header_id|>user<|end_header_id|>{message}<|eot_id|>
+<|start_header_id|>assistant<|end_header_id|>"""
+
+    inputs = tokenizer(
+        formatted_prompt,
+        return_tensors="pt",
+        padding=True
+    ).to(model.device)
+
+    # Create stopping criteria
+    end_tokens = [
+        tokenizer.encode(".")[0],
+        tokenizer.encode("!")[0],
+        tokenizer.encode("?")[0],
+        tokenizer.encode("<|eot_id|>", add_special_tokens=False)[0]
+    ]
+    stopping_criteria = StoppingCriteriaList([
+        SentenceEndingCriteria(tokenizer, end_tokens)
+    ])
+
+    # Modified generation parameters
+    with torch.no_grad():
+        outputs = model.generate(
+            **inputs,
+            max_new_tokens=300,
+            temperature=0.4,
+            do_sample=True,
+            top_p=0.95,
+            top_k=50,
+            repetition_penalty=1.2,
+            no_repeat_ngram_size=3,
+            pad_token_id=tokenizer.pad_token_id,
+            eos_token_id=tokenizer.encode("<|eot_id|>", add_special_tokens=False)[0],
+            stopping_criteria=stopping_criteria
+        )
+
+    response = tokenizer.decode(outputs[0], skip_special_tokens=False)
+
+    try:
+        assistant_parts = response.split("<|start_header_id|>assistant<|end_header_id|>")
+        last_response = assistant_parts[-1].split("<|eot_id|>")[0].strip()
+
+        # Ensure response ends with proper punctuation
+        if not any(last_response.rstrip().endswith(punct) for punct in ['.', '!', '?']):
+            # Find the last complete sentence
+            sentences = last_response.split('.')
+            if len(sentences) > 1:
+                last_response = '.'.join(sentences[:-1]) + '.'
+
+        return last_response
+    except:
+        return "I apologize, but I couldn't generate a proper response. Please try again."

+# Load model and tokenizer
+print("Loading model...")
+model, tokenizer = load_model()
+print("Model loaded!")

+# Create Gradio interface with chat
 demo = gr.ChatInterface(
+    fn=chat_response,
+    title="Admissions Agent Assistant",
+    description="Chat with an AI-powered admissions coordinator. The agent will maintain context of your conversation.",
+    examples=[
+        "I need help with addiction treatment",
+        "What insurance do you accept?",
+        "How long are your treatment programs?",
+        "Can you help with mental health issues?"
+    ]
 )

 if __name__ == "__main__":
+    demo.launch()  # Remove share=True as it's not needed for HF Spaces
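
A note on dependencies, separate from the commit itself: the new app.py imports torch, transformers (including BitsAndBytesConfig), peft, and gradio, and device_map="auto" plus load_in_8bit=True additionally rely on accelerate and bitsandbytes at runtime. The Space would therefore be expected to ship a requirements.txt along these lines; the list below is a sketch inferred from the imports above, with versions deliberately left unpinned.

    # Sketch of assumed Space dependencies, inferred from the imports in the new app.py.
    torch
    transformers
    accelerate      # needed for device_map="auto"
    bitsandbytes    # needed for load_in_8bit=True
    peft
    gradio

For reference, a minimal sketch of the prompt that format_chat_history and chat_response assemble, assuming a one-turn history (the sample messages are invented for illustration):

    # Illustration only; the sample history below is made up.
    history = [("Hi, I'm looking for treatment options",
                "Hello, I'm Fred. What brings you in today?")]
    # format_chat_history(history) returns:
    #   <|start_header_id|>user<|end_header_id|>Hi, I'm looking for treatment options<|eot_id|>
    #   <|start_header_id|>assistant<|end_header_id|>Hello, I'm Fred. What brings you in today?<|eot_id|>
    # chat_response() prepends the system prompt, appends the new user turn, and ends the
    # prompt with an open <|start_header_id|>assistant<|end_header_id|> header before
    # calling model.generate().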