# Hugging Face Space status at capture time: Runtime error
import gradio as gr | |
import transformers | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
from PIL import Image | |
import warnings | |
# Keep the console quiet: drop all Python warnings, and limit transformers to
# error-level logging with its download progress bars turned off.
warnings.filterwarnings('ignore')
transformers.logging.set_verbosity_error()
transformers.logging.disable_progress_bar()
# Initialize model and tokenizer
def load_model(device=None, model_id='qnguyen3/nanoLLaVA'):
    """Load the nanoLLaVA vision-language model and its tokenizer.

    Args:
        device: torch device or device string (e.g. ``'cpu'``, ``'cuda'``)
            to place the model on. ``None`` (the default) picks CUDA when it
            is available and otherwise falls back to the CPU. This matches
            the automatic placement the function always performed when the
            argument was ignored. An explicit ``'cpu'`` is now honoured.
        model_id: Hugging Face Hub repository to load both objects from.

    Returns:
        ``(model, tokenizer)``.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    # Use half precision only off-CPU. Many float16 kernels are missing (or
    # very slow) on CPU, which can make generation fail on CPU-only hosts.
    dtype = torch.float32 if device.type == 'cpu' else torch.float16

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        # Needed to load the repository's custom model code (which provides
        # e.g. process_images used by generate_caption).
        trust_remote_code=True,
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        trust_remote_code=True,
    )
    return model, tokenizer
def generate_caption(image, model, tokenizer,
                     prompt="Describe this image in detail",
                     max_new_tokens=2048):
    """Describe ``image`` with the nanoLLaVA vision-language model.

    Args:
        image: image to describe (a PIL image when called from this app).
        model: model from ``load_model``. It must expose ``process_images``,
            ``config``, ``dtype`` and ``device``, and accept ``images=`` in
            ``generate``.
        tokenizer: matching tokenizer with a chat template.
        prompt: instruction placed after the image placeholder.
        max_new_tokens: upper bound on the number of generated tokens.

    Returns:
        The decoded continuation with special tokens removed and surrounding
        whitespace stripped.
    """
    messages = [
        {"role": "system", "content": "Answer the question"},
        {"role": "user", "content": f'<image>\n{prompt}'},
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Tokenize the text on each side of the first placeholder separately and
    # splice in -200 where the image goes. Presumably the model's remote code
    # swaps this sentinel for image features (LLaVA's IMAGE_TOKEN_INDEX);
    # confirm against the model card. partition() keeps any later text intact
    # and never raises, unlike indexing the result of split().
    before, _, after = text.partition('<image>')
    input_ids = torch.tensor(
        tokenizer(before).input_ids + [-200] + tokenizer(after).input_ids,
        dtype=torch.long,
    ).unsqueeze(0).to(model.device)

    # Inputs must sit on the same device as the weights. Leaving them on the
    # CPU while the model is on a GPU makes generate() fail.
    image_tensor = model.process_images([image], model.config).to(
        device=model.device, dtype=model.dtype
    )

    output_ids = model.generate(
        input_ids,
        images=image_tensor,
        max_new_tokens=max_new_tokens,
        use_cache=True,
    )[0]

    # Drop the echoed prompt tokens; only the continuation is the caption.
    return tokenizer.decode(
        output_ids[input_ids.shape[1]:], skip_special_tokens=True
    ).strip()
def create_persona(caption):
    """Wrap an image caption in a role-play system prompt.

    Args:
        caption: free-text description of the image.

    Returns:
        One ChatML-style system message (``<|im_start|>system ... <|im_end|>``)
        instructing a chat model to stay in character as the described entity.
    """
    header = "<|im_start|>system"
    instructions = [
        f"You are a character based on this description: {caption}",
        "Role: An entity exactly as described in the image",
        "Background: Your appearance and characteristics match the image description",
        "Personality: Reflect the mood, style, and elements captured in the image",
        "Goal: Interact authentically based on your visual characteristics",
        "Please stay in character and respond as this entity would, incorporating visual elements from your description into your responses.",
    ]
    return "\n".join([header, *instructions]) + "<|im_end|>"
def process_image_to_persona(image, model, tokenizer):
    """Turn an uploaded image into a caption and a chatbot persona prompt.

    Args:
        image: PIL image, or an array accepted by ``Image.fromarray``.
            ``None`` means nothing was uploaded.
        model: vision-language model passed through to ``generate_caption``.
        tokenizer: tokenizer passed through to ``generate_caption``.

    Returns:
        ``(caption, persona)``. If ``image`` is ``None``, returns
        ``("Please upload an image.", "")`` so the UI shows a prompt instead.
    """
    if image is None:
        return "Please upload an image.", ""

    pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)
    caption = generate_caption(pil_image, model, tokenizer)
    return caption, create_persona(caption)
# Create Gradio interface
def create_interface():
    """Build (but do not launch) the Gradio UI for the persona generator.

    The model and tokenizer are loaded once, here. The "Generate Persona"
    button then calls ``process_image_to_persona`` with them and fills the
    caption and persona text boxes.

    Returns:
        The ``gr.Blocks`` app.
    """
    model, tokenizer = load_model()

    with gr.Blocks() as app:
        gr.Markdown("# Image to Chatbot Persona Generator")
        gr.Markdown("Upload an image of a character to generate a persona for a chatbot based on the image.")

        with gr.Row():
            uploaded = gr.Image(type="pil", label="Upload Character Image")
        with gr.Row():
            run_btn = gr.Button("Generate Persona")
        with gr.Row():
            caption_box = gr.Textbox(label="Generated Caption", lines=3)
            persona_box = gr.Textbox(label="Chatbot Persona", lines=10)

        # A lambda binds the already-loaded model/tokenizer so they are not
        # reloaded per click. It is kept as a lambda (not a named def) so the
        # callback's name, which Gradio may use for the endpoint, is unchanged.
        run_btn.click(
            fn=lambda picture: process_image_to_persona(picture, model, tokenizer),
            inputs=[uploaded],
            outputs=[caption_box, persona_box],
        )

    return app
# Launch the app only when this file is executed directly.
if __name__ == "__main__":
    # Building the interface also loads the model, so this step can be slow.
    app = create_interface()
    # share=True asks Gradio for a public share link as well (per Gradio docs;
    # presumably ignored when hosted on Spaces — confirm).
    app.launch(share=True)