File size: 3,824 Bytes
e3ced5a
 
49c812e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68adc61
49c812e
 
 
 
e3ced5a
49c812e
 
e3ced5a
49c812e
 
 
e3ced5a
 
 
49c812e
 
 
 
 
 
 
 
 
e3ced5a
49c812e
 
 
 
 
e3ced5a
 
 
 
 
49c812e
e3ced5a
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import gradio as gr

import os
import json

from unsloth import FastLanguageModel
import torch
max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "sayeed99/meta-llama3-8b-xtherapy-bnb-4bit", # YOUR MODEL YOU USED FOR TRAINING
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)
FastLanguageModel.for_inference(model) # Enable native 2x faster inference

# alpaca_prompt = You MUST copy from above!
formatted_string = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>You are Anna, a helpful AI assistant for mental therapy assistance developed by a team of developers at xnetics. If you do not know the user's name, start by asking the name. If you do not know details about user, ask them."

# Function to format the string
def format_chat_data(data):
    formatted_output = []
    if data["role"] == "assistant":
        value = data["content"]
        formatted_output.append("<|eot_id|><|start_header_id|>assistant<|end_header_id|>" + value)
    else:
        formatted_output.append("<|eot_id|><|start_header_id|>user<|end_header_id|>" + data["content"])

    return "".join(formatted_output)

def formatting_prompts_funcV2(examples):
    conversations = examples
    text = formatted_string
    for conversation in conversations:
        # Must add EOS_TOKEN, otherwise your generation will go on forever!
        text = text + format_chat_data(conversation)
    return text

def get_last_assistant_message(text):
    # Split the text by 'assistant' to isolate assistant's messages
    parts = text.split('<|start_header_id|>assistant<|end_header_id|>')
    
    # The last part is the last assistant message
    # Remove leading/trailing whitespace and return
    last_message = parts[-1].strip()
    last_message = cleanup(last_message)
    return last_message


def cleanup(text):
    # Check if the string ends with 'eot_id'
    if text.endswith('<|eot_id|>'):
        # Remove the last 10 characters
        return text[:-10]
    else:
        return text

# Define a function to handle the conversation and update the session
def handle_conversation(user_input):

    historyPrompt = formatting_prompts_funcV2(user_input)

    historyPrompt = historyPrompt + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
    inputs = tokenizer(
        [
            historyPrompt
        ], return_tensors="pt").to("cuda")

    outputs = model.generate(**inputs, max_new_tokens=512, use_cache=True)
    decoded_outputs = tokenizer.batch_decode(outputs)[0]
    # decoded_outputs = "Hello Welcome"
    last_message = get_last_assistant_message(decoded_outputs)

    # Return the AI response
    return last_message

def complete(messages):
    ai_response = handle_conversation(messages)
    return ai_response



def predict(message, history):
    history_openai_format = []
    for human, assistant in history:
        history_openai_format.append({"role": "user", "content": human })
        history_openai_format.append({"role": "assistant", "content":assistant})
    history_openai_format.append({"role": "user", "content": message})
  
    response = complete(history_openai_format)
    print(response)

    partial_message = ""
    for chunk in response:
        if chunk is not None:
              partial_message = partial_message + chunk
              yield partial_message

"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    predict
)


if __name__ == "__main__":
    demo.launch()