from pathlib import Path
from threading import Thread

from transformers import AutoConfig, AutoTokenizer, TextIteratorStreamer
from optimum.intel.openvino import OVModelForCausalLM
import openvino.properties as props
import openvino.properties.hint as hints
import openvino.properties.streams as streams
import gradio as gr

from llm_config import SUPPORTED_LLM_MODELS
# Initialize model language options
model_languages = list(SUPPORTED_LLM_MODELS)
# Helper function to retrieve model configuration and path
def get_model_path(model_language_value, model_id_value):
    model_configuration = SUPPORTED_LLM_MODELS[model_language_value][model_id_value]
    pt_model_name = model_id_value.split("-")[0]
    int4_model_dir = Path(model_id_value) / "INT4_compressed_weights"
    return model_configuration, int4_model_dir, pt_model_name
# Download the model if not already present
def download_model_if_needed(model_language_value, model_id_value):
    model_configuration, int4_model_dir, pt_model_name = get_model_path(model_language_value, model_id_value)
    int4_weights = int4_model_dir / "openvino_model.bin"
    if not int4_weights.exists():
        print(f"Downloading model {model_id_value}...")
        # Download logic (e.g., requests.get(model_configuration["model_url"])) can go here
    return int4_model_dir
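
# The download step above is left as a placeholder. One way to fill it in,
# shown here only as a hedged sketch: export the model with optimum-intel's
# 4-bit weight compression and save it to the INT4 directory. The
# "model_id" key is an assumption about the llm_config schema, not a
# confirmed field -- adjust it to whatever the configuration actually carries.
def export_int4_model(model_configuration, int4_model_dir):
    from optimum.intel import OVWeightQuantizationConfig

    model = OVModelForCausalLM.from_pretrained(
        model_configuration["model_id"],  # hypothetical config key
        export=True,
        quantization_config=OVWeightQuantizationConfig(bits=4),
        trust_remote_code=True,
    )
    model.save_pretrained(int4_model_dir)
    AutoTokenizer.from_pretrained(
        model_configuration["model_id"],  # hypothetical config key
        trust_remote_code=True,
    ).save_pretrained(int4_model_dir)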
# Load the model based on selected options
def load_model(model_language_value, model_id_value, device):
    int4_model_dir = download_model_if_needed(model_language_value, model_id_value)
    ov_config = {
        hints.performance_mode(): hints.PerformanceMode.LATENCY,
        streams.num(): "1",
        props.cache_dir(): "",
    }
    tok = AutoTokenizer.from_pretrained(int4_model_dir, trust_remote_code=True)
    ov_model = OVModelForCausalLM.from_pretrained(
        int4_model_dir,
        device=device,
        ov_config=ov_config,
        config=AutoConfig.from_pretrained(int4_model_dir, trust_remote_code=True),
        trust_remote_code=True,
    )
    return tok, ov_model
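
# load_model recompiles the model on every call, and the button handler calls
# it once per request. A simple memoization keyed by the UI selections avoids
# the repeated load cost. This is an optional sketch, not wired in by default;
# generate_response below could call load_model_cached instead of load_model.
_model_cache = {}

def load_model_cached(model_language_value, model_id_value, device):
    key = (model_language_value, model_id_value, device)
    if key not in _model_cache:
        _model_cache[key] = load_model(model_language_value, model_id_value, device)
    return _model_cache[key]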
# Render the conversation as plain text for the output textbox
def format_history(history):
    return "\n\n".join(f"User: {user}\nAssistant: {bot}" for user, bot in history)

# Define the function to generate responses, streamed token by token
def generate_response(message, history, temperature, top_p, top_k, repetition_penalty, model_language_value, model_id_value, device):
    tok, ov_model = load_model(model_language_value, model_id_value, device)
    history = history + [[message, ""]]

    # Flatten the conversation into a single prompt; a model-specific chat
    # template would normally be applied here instead
    def convert_history_to_token(history):
        return tok(format_history(history), return_tensors="pt").input_ids

    input_ids = convert_history_to_token(history)
    # generate() blocks until completion, so run it in a background thread
    # and stream partial text through a TextIteratorStreamer
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=256,
        do_sample=True,  # the sampling parameters below have no effect without this
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        streamer=streamer,
    )
    Thread(target=ov_model.generate, kwargs=generate_kwargs).start()
    # Stream response to textbox
    response = ""
    for new_text in streamer:
        response += new_text
        history[-1][1] = response
        yield format_history(history), history
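
# The prompt flattening above is deliberately naive. When the tokenizer ships
# a chat template, transformers' apply_chat_template builds the prompt in the
# format the model was trained on; a hedged alternative sketch:
def convert_history_with_template(tok, history):
    messages = []
    for user, bot in history:
        messages.append({"role": "user", "content": user})
        if bot:
            messages.append({"role": "assistant", "content": bot})
    return tok.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")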
# Define Gradio interface within a Blocks context
with gr.Blocks() as iface:
    # Dropdown for model language selection
    model_language = gr.Dropdown(
        choices=model_languages,
        value=model_languages[0],
        label="Model Language"
    )
    # Dropdown for model ID, pre-populated from the default language and
    # updated dynamically when the language changes
    default_model_ids = list(SUPPORTED_LLM_MODELS[model_languages[0]])
    model_id = gr.Dropdown(
        choices=default_model_ids,
        value=default_model_ids[0],
        label="Model"
    )
    # Update model_id choices when model_language changes
    def update_model_id(model_language_value):
        model_ids = list(SUPPORTED_LLM_MODELS[model_language_value])
        return gr.update(value=model_ids[0], choices=model_ids)

    model_language.change(update_model_id, inputs=model_language, outputs=model_id)
    # Checkbox for INT4 model preparation
    prepare_int4_model = gr.Checkbox(
        value=True,
        label="Prepare INT4 Model"
    )

    # Checkbox for enabling AWQ (shown conditionally)
    enable_awq = gr.Checkbox(
        value=False,
        label="Enable AWQ",
        visible=False  # visibility can be controlled in the UI logic
    )

    # Dropdown for device selection
    device = gr.Dropdown(
        choices=["CPU", "GPU"],
        value="CPU",
        label="Device"
    )

    # Sliders for model generation parameters
    temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, label="Temperature")
    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, label="Top P")
    top_k = gr.Slider(minimum=0, maximum=50, value=50, label="Top K")
    repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.1, label="Repetition Penalty")
    # Conversation history state
    history = gr.State([])

    # Textbox for the user's message
    user_message = gr.Textbox(label="Your Message")

    # Textbox for conversation history
    conversation_output = gr.Textbox(label="Conversation History", lines=10)

    # Button to trigger response generation
    generate_button = gr.Button("Generate Response")
    # Define action when button is clicked
    generate_button.click(
        generate_response,
        inputs=[user_message, history, temperature, top_p, top_k, repetition_penalty, model_language, model_id, device],
        outputs=[conversation_output, history],
    )
# Launch the Gradio app; queue() lets the generator stream partial results
if __name__ == "__main__":
    iface.queue().launch(debug=True, server_name="0.0.0.0", server_port=7860)