import gradio as gr
from transformers import AutoProcessor, AutoModelForVision2Seq, TextIteratorStreamer
from threading import Thread
import re
import time
from PIL import Image
import torch

# Check for GPU availability
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load model and processor.
# The model is placed explicitly with .to(device). Passing device_map as well is
# redundant and can conflict with accelerate's dispatch hooks, so it is omitted.
processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct")
model = AutoModelForVision2Seq.from_pretrained(
    "HuggingFaceTB/SmolVLM-Instruct",
    torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
).to(device)


def _load_image(file):
    """Open one uploaded file and return it as a fully loaded RGB PIL image.

    ``file`` is normally a filesystem path. A dict with a ``"path"`` key is
    also accepted (NOTE(review): assumed alternate payload shape in some
    Gradio versions -- confirm against the installed Gradio).
    """
    path = file["path"] if isinstance(file, dict) else file
    # The context manager closes the file handle; convert() returns a new,
    # fully loaded image that does not depend on the open file.
    with Image.open(path) as img:
        return img.convert("RGB")


# Inference function
def model_inference(
    input_dict, history, decoding_strategy, temperature, max_new_tokens, repetition_penalty, top_p
):
    """Stream the model's answer to one multimodal chat message.

    Args:
        input_dict: Gradio multimodal message with ``"text"`` (the query) and
            ``"files"`` (uploaded image files, possibly empty).
        history: Previous chat turns supplied by ``gr.ChatInterface``; unused
            (the app is designed for single-turn conversations).
        decoding_strategy: ``"Greedy"`` or ``"Top P Sampling"``.
        temperature: Sampling temperature, used only for Top P Sampling. A value
            <= 0 falls back to greedy decoding.
        max_new_tokens: Maximum number of tokens to generate.
        repetition_penalty: Penalty for repeated tokens (1.0 means no penalty).
        top_p: Nucleus probability mass, used only for Top P Sampling.

    Yields:
        The accumulated generated text after each streamed chunk.

    Raises:
        gr.Error: If the text query is empty.
        ValueError: If ``decoding_strategy`` is not a supported option.
    """
    text = input_dict.get("text") or ""
    images = [_load_image(f) for f in (input_dict.get("files") or [])]

    # A gr.Error must be *raised* to reach the UI; merely constructing it is a no-op.
    if not text.strip():
        if images:
            raise gr.Error("Please input a text query along with the image(s).")
        raise gr.Error("Please input a query and optionally image(s).")

    # One image placeholder per uploaded image, followed by the text query.
    resulting_messages = [
        {
            "role": "user",
            "content": [{"type": "image"} for _ in images] + [{"type": "text", "text": text}],
        }
    ]
    prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
    # The outer list is the batch (one prompt). Text-only queries pass no images at all.
    inputs = processor(text=prompt, images=[images] if images else None, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    generation_args = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
    }
    if decoding_strategy == "Greedy":
        generation_args["do_sample"] = False
    elif decoding_strategy == "Top P Sampling":
        if temperature > 0:
            generation_args["do_sample"] = True
            generation_args["temperature"] = temperature
            generation_args["top_p"] = top_p
        else:
            # transformers rejects temperature == 0 when sampling; treat it as greedy.
            generation_args["do_sample"] = False
    else:
        raise ValueError(f"Unsupported decoding strategy: {decoding_strategy!r}")

    # Stream generation: keep every decoding option and add the model inputs and streamer.
    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_args.update(inputs)
    generation_args["streamer"] = streamer

    # generate() runs in a background thread while this generator consumes the
    # streamer, so partial text reaches the UI as it is produced. Joining only
    # after the stream is exhausted keeps streaming incremental.
    thread = Thread(target=model.generate, kwargs=generation_args)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
    thread.join()


# Gradio interface
demo = gr.ChatInterface(
    fn=model_inference,
    title="Geoscience AI Interpreter",
    description=(
        "This app interprets thin sections, seismic images, etc. "
        "Upload an image and a text query. Works best with single-turn conversations. "
        "Clear the conversation after a single turn."
    ),
    textbox=gr.MultimodalTextbox(
        label="Query Input", file_types=["image"], file_count="multiple"
    ),
    stop_btn="Stop Generation",
    multimodal=True,
    additional_inputs=[
        gr.Radio(
            ["Top P Sampling", "Greedy"],
            value="Greedy",
            label="Decoding strategy",
            info="Greedy always picks the most likely token; Top P Sampling samples from the most probable tokens.",
        ),
        gr.Slider(
            minimum=0.0,
            maximum=5.0,
            value=0.4,
            step=0.1,
            interactive=True,
            label="Sampling temperature",
            info="Higher values produce more diverse outputs.",
        ),
        gr.Slider(
            minimum=8,
            maximum=1024,
            value=512,
            step=1,
            interactive=True,
            label="Maximum number of new tokens to generate",
        ),
        gr.Slider(
            minimum=0.01,
            maximum=5.0,
            value=1.2,
            step=0.01,
            interactive=True,
            label="Repetition penalty",
            info="1.0 is equivalent to no penalty.",
        ),
        gr.Slider(
            minimum=0.01,
            maximum=0.99,
            value=0.8,
            step=0.01,
            interactive=True,
            label="Top P",
            info="Higher values are equivalent to sampling more low-probability tokens.",
        ),
    ],
    cache_examples=False,
)

# Launch Gradio app
demo.launch(debug=True)