from pathlib import Path
from threading import Thread

from transformers import AutoConfig, AutoTokenizer, TextIteratorStreamer
from optimum.intel.openvino import OVModelForCausalLM
import openvino.properties as props
import openvino.properties.hint as hints
import openvino.properties.streams as streams
import gradio as gr

from llm_config import SUPPORTED_LLM_MODELS
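# llm_config (providing the SUPPORTED_LLM_MODELS registry) is assumed to be the
# helper module distributed with the OpenVINO LLM-chatbot notebooks; any nested
# mapping of {language: {model_id: configuration}} satisfies the code below.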

# Initialize model language options
model_languages = list(SUPPORTED_LLM_MODELS)

# Helper function to retrieve model configuration and path
def get_model_path(model_language_value, model_id_value):
    model_configuration = SUPPORTED_LLM_MODELS[model_language_value][model_id_value]
    pt_model_name = model_id_value.split("-")[0]
    int4_model_dir = Path(model_id_value) / "INT4_compressed_weights"
    return model_configuration, int4_model_dir, pt_model_name

# Download the model if not already present
def download_model_if_needed(model_language_value, model_id_value):
    model_configuration, int4_model_dir, pt_model_name = get_model_path(model_language_value, model_id_value)
    int4_weights = int4_model_dir / "openvino_model.bin"
    if not int4_weights.exists():
        print(f"Downloading model {model_id_value}...")
        # Download logic (e.g., requests.get(model_configuration["model_url"])) can go here
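        # A sketch of one way to produce the INT4 weights, assuming the
        # optimum-cli exporter is installed (the HF model id would come from
        # model_configuration; command shown for illustration only):
        #   optimum-cli export openvino --model <hf_model_id> \
        #       --weight-format int4 <int4_model_dir>
        # The export writes openvino_model.bin / openvino_model.xml into the
        # target directory, matching the existence check above.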
    return int4_model_dir

# Load the model based on selected options
def load_model(model_language_value, model_id_value, device):
    int4_model_dir = download_model_if_needed(model_language_value, model_id_value)
    ov_config = {
        hints.performance_mode(): hints.PerformanceMode.LATENCY,  # optimize for single-request latency
        streams.num(): "1",  # one inference stream is enough for a chat UI
        props.cache_dir(): ""  # disable the on-disk compiled-model cache
    }
    tok = AutoTokenizer.from_pretrained(int4_model_dir, trust_remote_code=True)
    ov_model = OVModelForCausalLM.from_pretrained(
        int4_model_dir,
        device=device,
        ov_config=ov_config,
        config=AutoConfig.from_pretrained(int4_model_dir, trust_remote_code=True),
        trust_remote_code=True
    )
    return tok, ov_model
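
# Optional: a cache keyed by (language, model_id, device) would avoid
# recompiling the model on every click; a minimal sketch:
#   _model_cache = {}
#   def load_model_cached(lang, mid, dev):
#       if (lang, mid, dev) not in _model_cache:
#           _model_cache[(lang, mid, dev)] = load_model(lang, mid, dev)
#       return _model_cache[(lang, mid, dev)]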

# Generate a response to the latest user message, streaming partial output
def generate_response(history, user_message, temperature, top_p, top_k, repetition_penalty, model_language_value, model_id_value, device):
    tok, ov_model = load_model(model_language_value, model_id_value, device)
    history = history + [[user_message, ""]]

    def convert_history_to_token(history):
        # Flatten the chat into one prompt string. A production app should
        # prefer the model's own chat template via tok.apply_chat_template.
        text = "\n".join(f"User: {user}\nAssistant: {assistant}" for user, assistant in history)
        return tok(text, return_tensors="pt").input_ids
    
    input_ids = convert_history_to_token(history)
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=256,
        do_sample=temperature > 0.0,  # sampling must be on for the knobs below to apply
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty
    )
    
    # Stream the reply: run generate() in a background thread and read
    # decoded text incrementally from the streamer as it arrives.
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs["streamer"] = streamer
    Thread(target=ov_model.generate, kwargs=generate_kwargs).start()

    response = ""
    for new_text in streamer:
        response += new_text
        history[-1][1] = response
        transcript = "\n".join(f"User: {u}\nAssistant: {a}" for u, a in history)
        yield transcript, history

# Define Gradio interface within a Blocks context
with gr.Blocks() as iface:
    # Dropdown for model language selection
    model_language = gr.Dropdown(
        choices=model_languages,
        value=model_languages[0],
        label="Model Language"
    )

    # Dropdown for model ID, dynamically populated
    model_id = gr.Dropdown(
        choices=[],  # will be populated dynamically
        label="Model",
        value=None
    )

    # Update model_id choices when model_language changes
    def update_model_id(model_language_value):
        model_ids = list(SUPPORTED_LLM_MODELS[model_language_value])
        return gr.update(value=model_ids[0], choices=model_ids)  # gr.Dropdown.update was removed in Gradio 4

    model_language.change(update_model_id, inputs=model_language, outputs=model_id)

    # Checkbox for INT4 model preparation (placeholder; not yet wired to a handler)
    prepare_int4_model = gr.Checkbox(
        value=True,
        label="Prepare INT4 Model"
    )

    # Checkbox for enabling AWQ (placeholder; shown conditionally)
    enable_awq = gr.Checkbox(
        value=False,
        label="Enable AWQ",
        visible=False  # visibility can be controlled in the UI logic
    )

    # Dropdown for device selection
    device = gr.Dropdown(
        choices=["CPU", "GPU"],
        value="CPU",
        label="Device"
    )
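    # The device list is hardcoded; openvino.Core().available_devices could be
    # used to populate it dynamically on machines with other accelerators.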

    # Sliders for model generation parameters
    temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, label="Temperature")
    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, label="Top P")
    top_k = gr.Slider(minimum=0, maximum=50, value=50, label="Top K")
    repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.1, label="Repetition Penalty")

    # Conversation history state
    history = gr.State([])

    # Textbox showing the running conversation
    conversation_output = gr.Textbox(label="Conversation History", lines=10)

    # Textbox for the user's next message
    user_input = gr.Textbox(label="Your Message")

    # Button to trigger response generation
    generate_button = gr.Button("Generate Response")

    # Define action when button is clicked
    generate_button.click(
        generate_response,
        inputs=[history, user_input, temperature, top_p, top_k, repetition_penalty, model_language, model_id, device],
        outputs=[conversation_output, history]
    )

# Launch the Gradio app
if __name__ == "__main__":
    iface.launch(debug=True, server_name="0.0.0.0", server_port=7860)
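
# To run locally (dependency names assumed; pin versions to taste):
#   pip install "optimum[openvino]" transformers gradio
#   python chatbot_app.py  # or whatever this file is saved as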