# llm-chatbot/app.py
from pathlib import Path
from threading import Event, Thread

import requests
from transformers import AutoConfig, AutoTokenizer, TextIteratorStreamer
from optimum.intel.openvino import OVModelForCausalLM
import openvino as ov
import openvino.properties as props
import openvino.properties.hint as hints
import openvino.properties.streams as streams
import gradio as gr
from llm_config import SUPPORTED_LLM_MODELS
# Initialize model language options
model_languages = list(SUPPORTED_LLM_MODELS)
# Gradio Interface inside Blocks
with gr.Blocks() as iface:
    model_language = gr.Dropdown(
        choices=model_languages,
        value=model_languages[0],
        label="Model Language",
    )
    model_id = gr.Dropdown(
        choices=[],  # will be dynamically populated
        label="Model",
        value=None,
    )
    # Function to update model_id dropdown choices based on model_language
    def update_model_id(model_language_value):
        model_ids = list(SUPPORTED_LLM_MODELS[model_language_value])
        return gr.update(value=model_ids[0], choices=model_ids)

    model_language.change(update_model_id, inputs=model_language, outputs=model_id)
# Gradio checkbox for preparing INT4 model
prepare_int4_model = gr.Checkbox(
value=True,
label="Prepare INT4 Model"
)
# Gradio checkbox for enabling AWQ (depends on INT4 checkbox)
enable_awq = gr.Checkbox(
value=False,
label="Enable AWQ",
visible=False
)
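    # The AWQ checkbox is created hidden and the comment above says it depends
    # on the INT4 checkbox, but no wiring exists here; a minimal sketch of that
    # intent (an assumption, not in the original):
    prepare_int4_model.change(
        lambda enabled: gr.update(visible=enabled),
        inputs=prepare_int4_model,
        outputs=enable_awq,
    )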
    # Gradio dropdown for device selection
    device = gr.Dropdown(
        choices=["CPU", "GPU"],
        value="CPU",
        label="Device",
    )
    # Model directory and setup based on selections
    def get_model_path(model_language_value, model_id_value):
        model_configuration = SUPPORTED_LLM_MODELS[model_language_value][model_id_value]
        pt_model_id = model_configuration["model_id"]
        pt_model_name = model_id_value.split("-")[0]
        int4_model_dir = Path(model_id_value) / "INT4_compressed_weights"
        return model_configuration, int4_model_dir, pt_model_name
    # Function to download the model if not already present
    def download_model_if_needed(model_language_value, model_id_value):
        model_configuration, int4_model_dir, pt_model_name = get_model_path(model_language_value, model_id_value)
        int4_weights = int4_model_dir / "openvino_model.bin"
        if not int4_weights.exists():
            print(f"Downloading model {model_id_value}...")
            # Add your download logic here (e.g., from a URL)
            # Example:
            # r = requests.get(model_configuration["model_url"])
            # with open(int4_weights, "wb") as f:
            #     f.write(r.content)
        return int4_model_dir
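    # A fleshed-out version of the placeholder above, as a sketch only: it
    # assumes the model configuration carries a "model_url" key pointing at a
    # downloadable weights file, which this file does not guarantee.
    def download_int4_weights(model_configuration, int4_model_dir):
        int4_model_dir.mkdir(parents=True, exist_ok=True)
        int4_weights = int4_model_dir / "openvino_model.bin"
        # Stream the response to disk to avoid holding large weights in memory
        with requests.get(model_configuration["model_url"], stream=True, timeout=60) as r:
            r.raise_for_status()
            with open(int4_weights, "wb") as f:
                for chunk in r.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        return int4_weights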
    # Load the model
    def load_model(model_language_value, model_id_value):
        int4_model_dir = download_model_if_needed(model_language_value, model_id_value)
        # Load the OpenVINO model
        ov_config = {
            hints.performance_mode(): hints.PerformanceMode.LATENCY,
            streams.num(): "1",
            props.cache_dir(): "",
        }
        core = ov.Core()
        model_dir = int4_model_dir
        model_configuration = SUPPORTED_LLM_MODELS[model_language_value][model_id_value]
        tok = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
        ov_model = OVModelForCausalLM.from_pretrained(
            model_dir,
            device=device.value,  # the dropdown's initial value; the live UI selection is not read here
            ov_config=ov_config,
            config=AutoConfig.from_pretrained(model_dir, trust_remote_code=True),
            trust_remote_code=True,
        )
        return tok, ov_model, model_configuration
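    # load_model is invoked from generate_response on every request, which
    # reloads the tokenizer and weights each time. A minimal memoization
    # sketch (added here, not part of the original file):
    _model_cache = {}

    def load_model_cached(model_language_value, model_id_value):
        key = (model_language_value, model_id_value)
        if key not in _model_cache:
            _model_cache[key] = load_model(model_language_value, model_id_value)
        return _model_cache[key]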
# Gradio UI for temperature and other model parameters
temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, label="Temperature")
top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, label="Top P")
top_k = gr.Slider(minimum=0, maximum=50, value=50, label="Top K")
repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.1, label="Repetition Penalty")
    # Conversation history input/output
    history = gr.State([])  # stores the conversation as [user, assistant] pairs
    # Gradio function for generating responses
    def generate_response(history, temperature, top_p, top_k, repetition_penalty, model_language_value, model_id_value):
        # Use the memoized loader defined above to avoid reloading weights per request
        tok, ov_model, model_configuration = load_model_cached(model_language_value, model_id_value)

        def convert_history_to_token(history):
            # Naive prompt construction: concatenate the user turns of the history
            input_tokens = tok(" ".join([msg[0] for msg in history]), return_tensors="pt").input_ids
            return input_tokens

        if not history:
            history = [["", ""]]  # guard so history[-1][1] below has a slot to fill
        input_ids = convert_history_to_token(history)
        # Iterable streamer that yields decoded text as generate() produces tokens
        streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
        generate_kwargs = dict(
            input_ids=input_ids,
            max_new_tokens=256,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            streamer=streamer,
        )

        event = Event()

        def generate_and_signal_complete():
            ov_model.generate(**generate_kwargs)
            event.set()  # signal that generation has finished

        t1 = Thread(target=generate_and_signal_complete)
        t1.start()

        # Accumulate streamed text and yield the updated history incrementally
        partial_text = ""
        for new_text in streamer:
            partial_text += new_text
            history[-1][1] = partial_text
            yield history
    # Interface setup (named separately so it does not shadow the Blocks
    # handle `iface`, which is what gets launched below)
    chat_interface = gr.Interface(
        fn=generate_response,
        inputs=[
            history,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
            model_language,
            model_id,
        ],
        outputs=[gr.Textbox(label="Conversation History")],
        live=True,
        title="OpenVINO Chatbot",
    )
# Launch Gradio app
if __name__ == "__main__":
    iface.launch(debug=True, share=True, server_name="0.0.0.0", server_port=7860)