"""SpaceLlama3.1 demo gradio app.""" | |
import datetime | |
import logging | |
import os | |
import gradio as gr | |
import torch | |
import PIL.Image | |
from prismatic import load | |
from huggingface_hub import login | |


# Authenticate with the Hugging Face Hub.
def authenticate_huggingface():
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        login(token=hf_token)
    else:
        raise ValueError(
            "Hugging Face API token not found. Please set it as an "
            "environment variable named 'HF_TOKEN'."
        )


# Call the authentication function once at start-up.
authenticate_huggingface()
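# NOTE: on Hugging Face Spaces, define HF_TOKEN as a repository secret
# (Settings -> Variables and secrets); secrets are exposed to the app as
# environment variables, which is what os.getenv("HF_TOKEN") reads.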

INTRO_TEXT = """SpaceLlama3.1 demo\n\n
| [Model](https://huggingface.co/remyxai/SpaceLlama3.1)
| [GitHub](https://github.com/remyxai/VQASynth/tree/main)
| [Demo](https://huggingface.co/spaces/remyxai/SpaceLlama3.1)
| [Discord](https://discord.gg/DAy3P5wYJk)
\n\n
**This is an experimental research model.** Make sure to add appropriate guardrails when using the model for applications.
"""

# Set the model location as a module-level constant.
MODEL_LOCATION = "remyxai/SpaceLlama3.1"  # Update as needed.
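
# A minimal caching sketch (not part of the original app): load the checkpoint
# once and reuse it across requests, since re-loading a multi-GB model on every
# call is slow. `_VLM` and `get_vlm` are names introduced here; compute() below
# uses this helper.
_VLM = None


def get_vlm(device):
    """Loads the VLM once, moves it to `device`, and caches it."""
    global _VLM
    if _VLM is None:
        _VLM = load(MODEL_LOCATION)
        _VLM.to(device, dtype=torch.bfloat16)
    return _VLM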


def compute(image, prompt):
    """Runs model inference."""
    if image is None:
        raise gr.Error("Image required")

    logging.info('prompt="%s"', prompt)

    # gr.Image(type="filepath") hands us a path; open it as an RGB image.
    if isinstance(image, str):
        image = PIL.Image.open(image).convert("RGB")

    # Set the device and fetch the (cached) model.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    vlm = get_vlm(device)

    # Build the chat-style prompt expected by the model.
    prompt_builder = vlm.get_prompt_builder()
    prompt_builder.add_turn(role="human", message=prompt)
    prompt_text = prompt_builder.get_prompt()

    # Generate text conditioned on the image and the prompt.
    generated_text = vlm.generate(
        image,
        prompt_text,
        do_sample=True,
        temperature=0.1,
        max_new_tokens=512,
        min_length=1,
    )
    # Keep only the text before the end-of-sequence marker.
    output = generated_text.split("</s>")[0]
    logging.info('output="%s"', output)
    return output  # A plain string, matching the Textbox output below.
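

# Example call (hypothetical file name and question, for illustration only):
#   compute("room.jpg", "How far is the chair from the table?")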


def reset():
    """Resets the input fields."""
    return "", None


def create_app():
    """Creates the demo UI."""
    with gr.Blocks() as demo:
        # Main UI structure.
        gr.Markdown(INTRO_TEXT)
        with gr.Row():
            image = gr.Image(value=None, label="Image", type="filepath", visible=True)  # input
            with gr.Column():
                prompt = gr.Textbox(value="", label="Prompt", visible=True)
                model_info = gr.Markdown(label="Model Info")
                run = gr.Button("Run", variant="primary")
                clear = gr.Button("Clear")
                # compute() returns a plain string, so a Textbox is the right
                # output component; gr.HighlightedText expects a list of
                # (text, label) pairs and would reject a bare string.
                output_text = gr.Textbox(value="", label="Output", visible=True)

        # Button event handlers.
        run.click(
            fn=compute,
            inputs=[image, prompt],
            outputs=output_text,
        )
        clear.click(fn=reset, inputs=None, outputs=[prompt, image])

        # Status.
        gr.Markdown(f"Startup: {datetime.datetime.now()}")
        gr.Markdown("GPU=?")

        demo.load(
            fn=lambda: f"Model `{MODEL_LOCATION}` loaded.",
            inputs=None,
            outputs=model_info,
        )

    return demo


if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
    )

    # Log the environment, redacting anything that looks like a credential
    # (logging raw values would leak HF_TOKEN into the Space logs).
    for k, v in os.environ.items():
        if any(s in k.upper() for s in ("TOKEN", "SECRET", "KEY", "PASSWORD")):
            v = "<redacted>"
        logging.info('environ["%s"] = %r', k, v)

    create_app().queue().launch()
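
# To run locally (assumes a valid Hugging Face token):
#   HF_TOKEN=hf_xxx python app.py
# Gradio serves on http://127.0.0.1:7860 by default.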