Spaces: SpaceLlama3.1 (Running on Zero)

smellslikeml committed · 22fc8c6
Parent(s): 013eddf

Add application file

Files changed:
- app.py (+110, -0)
- requirements.txt (+5, -0)
app.py (ADDED)
@@ -0,0 +1,110 @@
"""SpaceLlama3.1 demo gradio app."""

import datetime
import logging
import os

import gradio as gr
import requests
import torch
import PIL.Image

from prismatic import load

INTRO_TEXT = """SpaceLlama3.1 demo\n\n
| [Model](https://huggingface.co/remyxai/SpaceLlama3.1)
| [GitHub](https://github.com/remyxai/VQASynth/tree/main)
| [Demo](https://huggingface.co/spaces/remyxai/SpaceLlama3.1)
| [Discord](https://discord.gg/DAy3P5wYJk)
\n\n
**This is an experimental research model.** Make sure to add appropriate guardrails when using the model for applications.
"""


def compute(image, prompt, model_location):
    """Runs model inference."""
    if image is None:
        raise gr.Error("Image required")

    logging.info('prompt="%s"', prompt)

    # Open the image file.
    if isinstance(image, str):
        image = PIL.Image.open(image).convert("RGB")

    # Set device and load the model.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    vlm = load(model_location)
    vlm.to(device, dtype=torch.bfloat16)

    # Prepare the prompt.
    prompt_builder = vlm.get_prompt_builder()
    prompt_builder.add_turn(role="human", message=prompt)
    prompt_text = prompt_builder.get_prompt()

    # Generate text conditioned on the image and prompt.
    generated_text = vlm.generate(
        image,
        prompt_text,
        do_sample=True,
        temperature=0.1,
        max_new_tokens=512,
        min_length=1,
    )
    output = generated_text.split("</s>")[0]

    logging.info('output="%s"', output)

    # gr.HighlightedText expects a list of (text, label) tuples, not a bare string.
    return [(output, None)]


def reset():
    """Resets the input fields."""
    return "", None


def create_app():
    """Creates demo UI."""

    with gr.Blocks() as demo:
        # Main UI structure.
        gr.Markdown(INTRO_TEXT)
        with gr.Row():
            image = gr.Image(value=None, label="Image", type="filepath", visible=True)  # input
            with gr.Column():
                prompt = gr.Textbox(value="", label="Prompt", visible=True)
                model_info = gr.Markdown(label="Model Info")
                run = gr.Button("Run", variant="primary")
                clear = gr.Button("Clear")
                highlighted_text = gr.HighlightedText(value="", label="Output", visible=True)

        # Model location.
        model_location = "remyxai/SpaceLlama3.1"  # Update as needed.

        # Button event handlers. The constant is wrapped in gr.State because
        # event inputs must be Gradio components, not plain strings.
        run.click(
            compute,
            [image, prompt, gr.State(model_location)],
            highlighted_text,
        )
        clear.click(reset, None, [prompt, image])

        # Status.
        status = gr.Markdown(f"Startup: {datetime.datetime.now()}")
        gpu_kind = gr.Markdown("GPU=?")
        demo.load(
            # Single output component, so return the string directly.
            lambda: f"Model `{model_location}` loaded.",
            None,
            model_info,
        )

    return demo


if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
    )

    for k, v in os.environ.items():
        logging.info('environ["%s"] = %r', k, v)

    create_app().queue().launch()
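For reference, the same inference path can be exercised outside Gradio. This is a minimal sketch assuming the prismatic `load` / `get_prompt_builder` / `generate` API used in `compute()` above; the image path and question are hypothetical placeholders.

# Standalone sketch of the inference path used by compute() above.
# "example.jpg" and the question are hypothetical placeholders.
import torch
import PIL.Image
from prismatic import load

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
vlm = load("remyxai/SpaceLlama3.1")
vlm.to(device, dtype=torch.bfloat16)

image = PIL.Image.open("example.jpg").convert("RGB")
prompt_builder = vlm.get_prompt_builder()
prompt_builder.add_turn(role="human", message="How far is the chair from the table?")
prompt_text = prompt_builder.get_prompt()

generated_text = vlm.generate(
    image,
    prompt_text,
    do_sample=True,
    temperature=0.1,
    max_new_tokens=512,
)
print(generated_text.split("</s>")[0])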
requirements.txt (ADDED)
@@ -0,0 +1,5 @@
gradio
Pillow
torch
requests
git+https://github.com/remyxai/prismatic-vlms.git
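To try the app outside the Space, something like `pip install -r requirements.txt` followed by `python app.py` should launch the Gradio UI locally; on the Space itself, the ZeroGPU runtime provides the hardware. Note that the model is loaded on every request in `compute()`, which keeps the ZeroGPU worker stateless at the cost of slower individual calls.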