File size: 5,528 Bytes
6ee0bfb
 
 
 
dd7576a
10159e5
dd7576a
6ee0bfb
 
7ac6137
 
 
 
 
 
6ee0bfb
dd7576a
6ee0bfb
10159e5
 
 
6ee0bfb
 
 
 
 
10159e5
6ee0bfb
 
 
 
 
 
 
 
10159e5
 
6ee0bfb
 
 
dd7576a
6ee0bfb
 
 
 
 
 
 
 
dd7576a
e2221cc
6ee0bfb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dd7576a
25486d0
 
 
ee306f2
 
25486d0
 
 
 
6ee0bfb
 
 
 
dd7576a
6ee0bfb
e2221cc
 
6ee0bfb
e68f0e3
6ee0bfb
 
 
 
 
 
dd7576a
 
 
25486d0
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import torch
from transformers import AutoTokenizer, LlamaForCausalLM, BitsAndBytesConfig
from transformers.generation.stopping_criteria import StoppingCriteria, StoppingCriteriaList
from peft import PeftModel
import gradio as gr
import os

# Custom stopping criterion: end generation once a sentence-ending token is produced.
class SentenceEndingCriteria(StoppingCriteria):
    """Halt generation when the newest token is one of a given set of ids.

    Only the first sequence of the batch (row 0) is inspected.

    Args:
        tokenizer: The tokenizer the ids belong to. It is stored on the
            instance but not consulted by the check.
        end_tokens: Collection of token ids that end generation when they
            appear as the most recently generated token.
    """

    def __init__(self, tokenizer, end_tokens):
        self.tokenizer = tokenizer
        self.end_tokens = end_tokens

    def __call__(self, input_ids, scores, **kwargs):
        # Newest token id of the first sequence in the batch.
        newest = input_ids[0, -1]
        return newest in self.end_tokens

def load_model(model_path="Cioni223/mymodel"):
    """Load the fine-tuned Llama model and its tokenizer.

    Args:
        model_path: Hugging Face Hub repo id (or local path) to load from.
            Defaults to the repository this app was built for.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is loaded in 8-bit through
        bitsandbytes with ``device_map="auto"``. The tokenizer pads on the
        left and uses its EOS token as the pad token.
    """
    # Access token for gated/private repos; None means anonymous access.
    token = os.environ.get("HUGGINGFACE_TOKEN")  # Ensure you set this environment variable

    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        use_fast=False,
        padding_side="left",  # decoder-only generation expects left padding
        model_max_length=4096,
        token=token,
    )

    # Reuse EOS as the padding token so `padding=True` works at call time.
    tokenizer.pad_token = tokenizer.eos_token

    model = LlamaForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
        torch_dtype=torch.float16,
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        # `token` supersedes the deprecated `use_auth_token` and matches the
        # tokenizer call above.
        token=token,
    )

    return model, tokenizer

def format_chat_history(history):
    """Render prior chat turns as header/``<|eot_id|>``-delimited prompt text.

    Two Gradio history shapes are accepted:

    * "tuples" format: a sequence of ``(user_msg, assistant_msg)`` pairs
      (lists work too).
    * "messages" format: a sequence of ``{"role": ..., "content": ...}``
      dicts.

    Empty or ``None`` messages are skipped. Dict messages whose role is
    neither ``"user"`` nor ``"assistant"`` are skipped as well.

    Args:
        history: The prior conversation, or ``None``.

    Returns:
        One ``<|start_header_id|>{role}<|end_header_id|>{text}<|eot_id|>\\n``
        segment per non-empty message, concatenated in order. Returns an
        empty string when there is nothing to render.
    """
    segments = []
    for entry in history or ():
        if isinstance(entry, dict):
            # "messages" format: each dict is a single turn.
            turns = ((entry.get("role"), entry.get("content")),)
        else:
            # "tuples" format: each pair holds a user turn then an assistant turn.
            user_msg, assistant_msg = entry
            turns = (("user", user_msg), ("assistant", assistant_msg))
        for role, content in turns:
            if role in ("user", "assistant") and content:
                segments.append(
                    f"<|start_header_id|>{role}<|end_header_id|>{content}<|eot_id|>\n"
                )
    # join avoids the quadratic cost of repeated string concatenation.
    return "".join(segments)

def chat_response(message, history=None):
    """Generate the assistant's next reply for one chat turn.

    Builds the prompt from four parts: the system prompt, the prior turns,
    the new user message and an open assistant header. It then samples a
    continuation from the module-level ``model``/``tokenizer`` and trims the
    result to complete sentences.

    Args:
        message: The user's new message.
        history: Prior turns in the form ``format_chat_history`` accepts.
            Optional, so single-input callers (e.g. a one-textbox
            ``gr.Interface``) can call this with the message alone.

    Returns:
        The reply text. If the model produced no usable text, returns a
        fixed apology message instead.
    """
    # Format the prompt with system message and chat history
    system_prompt = """<|start_header_id|>system<|end_header_id|>You are Fred, a virtual admissions coordinator for Haven Health Management, a mental health and substance abuse treatment facility. Your role is to respond conversationally and empathetically, like a human agent, using 1-2 sentences per response while guiding the conversation effectively. Your primary goal is to understand the caller's reason for reaching out, gather their medical history, and obtain their insurance details, ensuring the conversation feels natural and supportive. Once all the information is gathered politely end the conversation and if the user is qualified tell the user a live agent will reach out soon. Note: Medicaid is not accepted as insurance.<|eot_id|>"""

    # A missing history means there are no prior turns.
    chat_history = format_chat_history(history or [])

    formatted_prompt = f"""{system_prompt}
{chat_history}<|start_header_id|>user<|end_header_id|>{message}<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>"""

    inputs = tokenizer(
        formatted_prompt,
        return_tensors="pt",
        padding=True
    ).to(model.device)

    eot_token_id = tokenizer.encode("<|eot_id|>", add_special_tokens=False)[0]

    # Stop at sentence-ending punctuation or at <|eot_id|>.
    # add_special_tokens=False is required: with the default (True), Llama
    # tokenizers prepend BOS, so `encode(".")[0]` would be the BOS id rather
    # than the punctuation token, and the criterion would never fire.
    end_tokens = [
        tokenizer.encode(punct, add_special_tokens=False)[-1]
        for punct in (".", "!", "?")
    ]
    end_tokens.append(eot_token_id)
    stopping_criteria = StoppingCriteriaList([
        SentenceEndingCriteria(tokenizer, end_tokens)
    ])

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=300,
            temperature=0.4,
            do_sample=True,
            top_p=0.95,
            top_k=50,
            repetition_penalty=1.2,
            no_repeat_ngram_size=3,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=eot_token_id,
            stopping_criteria=stopping_criteria
        )

    # `generate` returns prompt + continuation, so decode only the new tokens.
    # Otherwise header markers typed by the user could be mistaken for the reply.
    prompt_length = inputs["input_ids"].shape[1]
    reply = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=False)

    # Keep only the assistant turn. Cut at end-of-turn, end-of-text, or the
    # start of a new header if the model begins another turn.
    for marker in ("<|eot_id|>", "<|start_header_id|>", "<|end_of_text|>"):
        reply = reply.split(marker)[0]
    reply = reply.strip()

    # If generation stopped mid-sentence (e.g. the max_new_tokens limit),
    # drop the unfinished fragment after the last sentence terminator, if any.
    if reply and not reply.endswith((".", "!", "?")):
        last_stop = max(reply.rfind(punct) for punct in ".!?")
        if last_stop != -1:
            reply = reply[:last_stop + 1]

    if not reply:
        return "I apologize, but I couldn't generate a proper response. Please try again."
    return reply

def _api_chat_response(message):
    """Answer a single message with no prior chat history.

    ``api_interface`` has one input component, so Gradio calls its ``fn``
    with the message only. This adapter supplies the empty history that
    ``chat_response`` takes as its second argument.
    """
    return chat_response(message, [])


# Define a Gradio Interface for the API
api_interface = gr.Interface(
    fn=_api_chat_response,
    inputs=gr.Textbox(lines=2, placeholder="Enter your message here..."),
    outputs=gr.Textbox(label="Response"),
    title="Admissions Agent API",
    description="API endpoint for interacting with the AI-powered admissions coordinator."
)

# Load model and tokenizer (runs at import time; chat_response reads these globals)
print("Loading model...")
model, tokenizer = load_model()
print("Model loaded!")

# Create Gradio interface with chat
demo = gr.ChatInterface(
    fn=chat_response,
    title="Admissions Agent Assistant",
    description="Chat with an AI-powered admissions coordinator. The agent will maintain context of your conversation.",
    examples=[
        "I need help with addiction treatment",
        "What insurance do you accept?",
        "How long are your treatment programs?",
        "Can you help with mental health issues?"
    ]
)

if __name__ == "__main__":
    # NOTE(review): when run as a script, `demo.launch()` blocks the main
    # thread. The API launch below is only reached after the chat server
    # shuts down. If both servers should run at once, consider
    # `demo.launch(prevent_thread_lock=True)` or a single
    # `gr.TabbedInterface`. Check the deployment's port configuration first,
    # because two servers cannot share one fixed port.
    demo.launch()
    api_interface.launch(share=True)  # This will expose the API endpoint