# SpaceLLaVA / app.py
import io
import base64
import gradio as gr
from PIL import Image
from llama_cpp import Llama
from llama_cpp.llama_chat_format import Llava15ChatHandler
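# Runtime dependencies (inferred from the imports above, not a pinned list):
# llama-cpp-python with its bundled LLaVA 1.5 chat handler, gradio, and Pillow.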
# Converts an image input (PIL Image or file path) into a base64 data URI.
def image_to_base64_data_uri(image_input):
    if isinstance(image_input, str):
        with open(image_input, "rb") as img_file:
            base64_data = base64.b64encode(img_file.read()).decode('utf-8')
    elif isinstance(image_input, Image.Image):
        buffer = io.BytesIO()
        image_input.save(buffer, format="PNG")
        base64_data = base64.b64encode(buffer.getvalue()).decode('utf-8')
    else:
        raise ValueError("Unsupported input type. Input must be a file path or a PIL.Image.Image instance.")
    return f"data:image/png;base64,{base64_data}"
class Llava:
    def __init__(self, mmproj="model/mmproj-model-f16.gguf", model_path="model/ggml-model-q4_0.gguf", gpu=False):
        # The CLIP projector (mmproj) lets the LLaVA 1.5 chat handler embed images.
        chat_handler = Llava15ChatHandler(clip_model_path=mmproj, verbose=True)
        # Offload all layers to the GPU when requested; 0 keeps inference on the CPU.
        n_gpu_layers = -1 if gpu else 0
        self.llm = Llama(
            model_path=model_path,
            chat_handler=chat_handler,
            n_ctx=2048,       # room for the image embedding plus the prompt
            logits_all=True,  # recommended by llama-cpp-python for the LLaVA chat handler
            n_gpu_layers=n_gpu_layers,
        )
    def run_inference(self, image, prompt):
        data_uri = image_to_base64_data_uri(image)
        res = self.llm.create_chat_completion(
            messages=[
                {"role": "system", "content": "You are an assistant who perfectly describes images."},
                {
                    "role": "user",
                    "content": [
                        {"type": "image_url", "image_url": {"url": data_uri}},
                        {"type": "text", "text": prompt}
                    ]
                }
            ]
        )
        return res["choices"][0]["message"]["content"]
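# Usage sketch for the class above (assumes the GGUF weights are present under model/;
# the example image path and question are hypothetical):
#   model = Llava(gpu=False)
#   answer = model.run_inference(Image.open("examples/warehouse_1.jpg"),
#                                "Is the pallet closer to the camera than the forklift?")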
# Initialize the model
llm_model = Llava()
title_and_links_markdown = """
# 🛸SpaceLLaVA🌋: A spatial reasoning multi-modal model
This Space hosts our initial release of LLaVA 1.5, LoRA-tuned for spatial reasoning using data generated with [VQASynth](https://github.com/remyxai/VQASynth).
Upload an image and ask a question.
[Model](https://huggingface.co/remyxai/SpaceLLaVA) | [Code](https://github.com/remyxai/VQASynth) | [Paper](https://spatial-vlm.github.io)
"""
def predict(image, prompt):
    result = llm_model.run_inference(image, prompt)
    return result
# gr.inputs was removed in Gradio 4; the top-level components below work in Gradio 3 and 4.
image_input = gr.Image(type="pil", label="Input Image")
text_input = gr.Textbox(label="Prompt")
# Define the examples before constructing the interface so Gradio renders them;
# assigning iface.examples after construction does not register an Examples component.
examples = [
    ["examples/warehouse_1.jpg", "Is the man wearing gray pants to the left of the pile of boxes on a pallet?"],
    ["examples/warehouse_2.jpg", "Is the forklift taller than the shelves of boxes?"],
]

# Initialize the interface, reusing the markdown above so the model links appear on the page.
iface = gr.Interface(
    fn=predict,
    inputs=[image_input, text_input],
    outputs="text",
    title="Llava Model Inference",
    description=title_and_links_markdown,
    examples=examples,
)

iface.launch()
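# For local debugging, a temporary public URL can be requested instead (a standard
# Gradio option, not something this Space enables by default):
#   iface.launch(share=True)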