import io
import base64
import numpy as np
import torch
import matplotlib
import matplotlib.cm
import gradio as gr
from PIL import Image
from llama_cpp import Llama
from llama_cpp.llama_chat_format import Llava15ChatHandler
# Converts an image input (PIL Image or file path) into a base64 data URI
def image_to_base64_data_uri(image_input):
    if isinstance(image_input, str):
        with open(image_input, "rb") as img_file:
            base64_data = base64.b64encode(img_file.read()).decode('utf-8')
    elif isinstance(image_input, Image.Image):
        buffer = io.BytesIO()
        image_input.save(buffer, format="PNG")
        base64_data = base64.b64encode(buffer.getvalue()).decode('utf-8')
    else:
        raise ValueError("Unsupported input type. Input must be a file path or a PIL.Image.Image instance.")
    return f"data:image/png;base64,{base64_data}"
class Llava:
    def __init__(self, mmproj="model/mmproj-model-f16.gguf", model_path="model/ggml-model-q4_0.gguf", gpu=False):
        chat_handler = Llava15ChatHandler(clip_model_path=mmproj, verbose=True)
        n_gpu_layers = 0
        if gpu:
            n_gpu_layers = -1  # offload all layers to the GPU
        self.llm = Llama(model_path=model_path, chat_handler=chat_handler, n_ctx=2048, logits_all=True, n_gpu_layers=n_gpu_layers)

    def run_inference(self, image, prompt):
        # Encode the image as a data URI and send it alongside the prompt as a chat completion
        data_uri = image_to_base64_data_uri(image)
        res = self.llm.create_chat_completion(
            messages=[
                {"role": "system", "content": "You are an assistant who perfectly describes images."},
                {
                    "role": "user",
                    "content": [
                        {"type": "image_url", "image_url": {"url": data_uri}},
                        {"type": "text", "text": prompt}
                    ]
                }
            ]
        )
        return res["choices"][0]["message"]["content"]
llm_model = Llava()
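
# Gradio callback: run LLaVA inference on the uploaded image with the user's prompt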
def predict(image, prompt):
    result = llm_model.run_inference(image, prompt)
    return result
title_and_links_markdown = """
# 🛸SpaceLLaVA🌋: A spatial reasoning multi-modal model

This space hosts our initial release of LLaVA-1.5, LoRA-tuned for spatial reasoning on data generated with [VQASynth](https://github.com/remyxai/VQASynth).
Upload an image and ask a question.

[Model](https://huggingface.co/remyxai/SpaceLLaVA) | [Code](https://github.com/remyxai/VQASynth) | [Paper](https://spatial-vlm.github.io)
"""
examples = [
    ["examples/warehouse_1.jpg", "Is the man wearing gray pants to the left of the pile of boxes on a pallet?"],
    ["examples/warehouse_2.jpg", "Is the forklift taller than the shelves of boxes?"],
]
# Create the Gradio interface with the markdown description
iface = gr.Interface(
    fn=predict,
    inputs=[gr.Image(type="pil", label="Input Image"), gr.Textbox(label="Prompt")],
    outputs=gr.Textbox(),
    examples=examples,
    title="🛸SpaceLLaVA🌋: A spatial reasoning multi-modal model",
    description=title_and_links_markdown  # rendered as markdown below the title
)
# Launch the Gradio app
iface.launch()