import io
import base64
import numpy as np
import torch
import matplotlib
import matplotlib.cm
import gradio as gr
from PIL import Image

from llama_cpp import Llama
from llama_cpp.llama_chat_format import Llava15ChatHandler
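# llama-cpp-python runs the quantized GGUF weights locally; Llava15ChatHandler adds
# LLaVA 1.5 image support on top of the CLIP projector ("mmproj") model.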

# Converts an image input (PIL Image or file path) into a base64 data URI
def image_to_base64_data_uri(image_input):
    if isinstance(image_input, str):
        with open(image_input, "rb") as img_file:
            base64_data = base64.b64encode(img_file.read()).decode('utf-8')
    elif isinstance(image_input, Image.Image):
        buffer = io.BytesIO()
        image_input.save(buffer, format="PNG")
        base64_data = base64.b64encode(buffer.getvalue()).decode('utf-8')
    else:
        raise ValueError("Unsupported input type. Input must be a file path or a PIL.Image.Image instance.")
    return f"data:image/png;base64,{base64_data}"

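# Thin wrapper around llama-cpp-python: loads the quantized LLaVA language model
# together with its CLIP projector and exposes a single-image chat completion.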
class Llava:
    def __init__(self, mmproj="model/mmproj-model-f16.gguf", model_path="model/ggml-model-q4_0.gguf", gpu=False):
        chat_handler = Llava15ChatHandler(clip_model_path=mmproj, verbose=True)
        # -1 offloads all layers to the GPU; 0 keeps inference on the CPU
        n_gpu_layers = -1 if gpu else 0
        # logits_all=True is required by the LLaVA chat handler; n_ctx is raised to fit the image embedding plus the prompt
        self.llm = Llama(
            model_path=model_path,
            chat_handler=chat_handler,
            n_ctx=2048,
            logits_all=True,
            n_gpu_layers=n_gpu_layers,
        )

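    # Encode the image as a data URI, send it with the prompt as a single-turn
    # chat completion, and return the model's text response.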
    def run_inference(self, image, prompt):
        data_uri = image_to_base64_data_uri(image)
        res = self.llm.create_chat_completion(
            messages=[
                {"role": "system", "content": "You are an assistant who perfectly describes images."},
                {
                    "role": "user",
                    "content": [
                        {"type": "image_url", "image_url": {"url": data_uri}},
                        {"type": "text", "text": prompt}
                    ]
                }
            ]
        )
        return res["choices"][0]["message"]["content"]

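# Load the model once at startup so every Gradio request reuses the same instance.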
llm_model = Llava()

def predict(image, prompt):
    result = llm_model.run_inference(image, prompt)
    return result

title_and_links_markdown = """
# 🛸SpaceLLaVA🌋: A spatial reasoning multi-modal model
This Space hosts our initial release of LLaVA 1.5, LoRA-tuned for spatial reasoning on data generated with [VQASynth](https://github.com/remyxai/VQASynth).
Upload an image and ask a question.

[Model](https://huggingface.co/remyxai/SpaceLLaVA) | [Code](https://github.com/remyxai/VQASynth) | [Paper](https://spatial-vlm.github.io)
"""

examples = [
    ["examples/warehouse_1.jpg", "Is the man wearing gray pants to the left of the pile of boxes on a pallet?"],
    ["examples/warehouse_2.jpg", "Is the forklift taller than the shelves of boxes?"],
]


# Create the Gradio interface with the markdown
iface = gr.Interface(
    fn=predict,
    inputs=[gr.Image(type="pil", label="Input Image"), gr.Textbox(label="Prompt")],
    outputs=gr.Textbox(),
    examples=examples,
    title="🛸SpaceLLaVA🌋: A spatial reasoning multi-modal model",
    description=title_and_links_markdown  # Use description for markdown
)

# Launch the Gradio app
iface.launch()