import spaces
import os
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig, AutoProcessor
import gradio as gr
from threading import Thread
from PIL import Image
import subprocess
# Install flash-attn at runtime, skipping the CUDA build step
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
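
# Sketch (not part of the original app): choose the attention implementation
# based on whether flash-attn actually imported, so the vision model below can
# still load if the install above failed. "eager" is the documented fallback
# for Phi-3.5-vision; ATTN_IMPLEMENTATION is a name introduced here.
try:
    import flash_attn  # noqa: F401
    ATTN_IMPLEMENTATION = "flash_attention_2"
except ImportError:
    ATTN_IMPLEMENTATION = "eager"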
# Model and tokenizer for the chatbot (loaded here, but not wired to a tab below)
MODEL_ID1 = "justinj92/phi-35-vision-burberry"
MODEL_LIST1 = ["justinj92/phi-35-vision-burberry"]
HF_TOKEN = os.environ.get("HF_TOKEN", None)

device = "cuda" if torch.cuda.is_available() else "cpu"  # the vision model below is moved to CUDA, so a GPU is required in practice

# 4-bit NF4 quantization with double quantization and bfloat16 compute
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID1, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID1,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=quantization_config,
    trust_remote_code=True,
)
# Vision model setup, keyed by model id so the dropdown below can select one
models = {
    "justinj92/phi-35-vision-burberry": AutoModelForCausalLM.from_pretrained(
        "justinj92/phi-35-vision-burberry",
        trust_remote_code=True,
        torch_dtype="auto",
        _attn_implementation=ATTN_IMPLEMENTATION,  # falls back to "eager" if flash-attn is unavailable (see above)
    ).cuda().eval()
}
processors = {
    "justinj92/phi-35-vision-burberry": AutoProcessor.from_pretrained(
        "justinj92/phi-35-vision-burberry", trust_remote_code=True
    )
}
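
# Because everything is keyed by model id, additional checkpoints could be
# exposed in the dropdown by adding entries to both dicts. A hypothetical
# sketch (the model id below is illustrative, not a real checkpoint):
#
#   models["your-org/another-phi35-vision"] = AutoModelForCausalLM.from_pretrained(
#       "your-org/another-phi35-vision", trust_remote_code=True,
#       torch_dtype="auto", _attn_implementation=ATTN_IMPLEMENTATION,
#   ).cuda().eval()
#   processors["your-org/another-phi35-vision"] = AutoProcessor.from_pretrained(
#       "your-org/another-phi35-vision", trust_remote_code=True,
#   )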
# Prompt fragments (currently unused; the chat template below builds the prompt)
user_prompt = '\n'
assistant_prompt = '\n'
prompt_suffix = "\n"
# Vision model tab function (returns the full response; no token streaming)
def stream_vision(image, model_id="justinj92/phi-35-vision-burberry"):
    model = models[model_id]
    processor = processors[model_id]

    if image is None:
        return "Please upload an image first."

    text_input = "What is shown in this image?"

    # Prepare the image list; gr.Image passes a numpy array by default
    images = [Image.fromarray(image).convert("RGB")]
    placeholder = "<|image_1|>\n"  # Phi-3.5-vision tag for the first image

    # Construct the prompt with the image tag followed by the question
    if text_input:
        prompt_content = placeholder + text_input
    else:
        prompt_content = placeholder

    messages = [
        {"role": "user", "content": prompt_content},
    ]

    # Apply the chat template to the messages
    prompt = processor.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Process the text and image inputs together
    inputs = processor(prompt, images, return_tensors="pt").to("cuda:0")

    # Generation parameters: greedy decoding (temperature would be ignored with do_sample=False)
    generation_args = {
        "max_new_tokens": 2000,
        "do_sample": False,
    }

    # Generate the response
    generate_ids = model.generate(
        **inputs,
        eos_token_id=processor.tokenizer.eos_token_id,
        **generation_args,
    )

    # Remove the input tokens from the generated sequence
    generate_ids = generate_ids[:, inputs["input_ids"].shape[1]:]

    # Decode the generated output
    response = processor.batch_decode(
        generate_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
    return response
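
# A minimal sketch of calling stream_vision outside Gradio (assumes a local
# file named "sample.jpg"; the filename is illustrative). gr.Image hands the
# function a numpy array by default, so the call mirrors that:
#
#   import numpy as np
#   img = np.array(Image.open("sample.jpg").convert("RGB"))
#   print(stream_vision(img, "justinj92/phi-35-vision-burberry"))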
# CSS for the interface
CSS = """
.duplicate-button {
    margin: auto !important;
    color: white !important;
    background: black !important;
    border-radius: 100vh !important;
}
h3 {
    text-align: center;
}
"""
TITLE = "<h1><center>Burberry Product Categorizer</center></h1>"

EXPLANATION = """
<div style="text-align: center; margin-top: 20px;">
    <p>This app uses Microsoft's Phi-3.5 Vision model.</p>
    <p>The fine-tuned version is built on an open Burberry product dataset.</p>
</div>
"""

footer = """
<div style="text-align: center; margin-top: 20px;">
    <a href="https://www.linkedin.com/in/justin-j-4a77456b/" target="_blank">LinkedIn</a>
    <br>
    Made with π by Justin J
</div>
"""
# Gradio app with a single vision tab
with gr.Blocks(css=CSS, theme=gr.themes.Default(primary_hue=gr.themes.colors.red, secondary_hue=gr.themes.colors.pink)) as demo:
    gr.HTML(TITLE)
    gr.HTML(EXPLANATION)
    with gr.Tab("Burberry Vision"):
        with gr.Row():
            input_img = gr.Image(label="Upload a Burberry Product Image")
        with gr.Row():
            model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value="justinj92/phi-35-vision-burberry")
        # with gr.Row():
        #     text_input = gr.Textbox(label="Question")
        with gr.Row():
            submit_btn = gr.Button(value="Tell me about this product")
        with gr.Row():
            output_text = gr.Textbox(label="Product Info")
        submit_btn.click(stream_vision, [input_img, model_selector], [output_text])
    gr.HTML(footer)

# Launch the app
demo.launch(debug=True)