import base64
import io

import gradio as gr
from huggingface_hub import InferenceClient
from PIL import Image

# Initialize the Hugging Face Inference Client
client = InferenceClient("microsoft/llava-med-7b-delta")


def encode_image_base64(image: Image.Image) -> str:
    """Encode a PIL image as a base64 PNG string."""
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


# Function to interact with the LLaVA model
def respond(message, image, system_message, max_tokens, temperature, top_p):
    messages = [{"role": "system", "content": system_message}]

    if image is not None:
        # Attach the uploaded image to the question as a base64 data URI,
        # using the OpenAI-style multimodal chat-completion message format
        image_b64 = encode_image_base64(image)
        messages.append(
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": message},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{image_b64}"},
                    },
                ],
            }
        )
    else:
        messages.append({"role": "user", "content": message})

    # Stream the model's reply, accumulating tokens into a single string
    try:
        response_text = ""
        for chunk in client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            token = chunk.choices[0].delta.content
            if token:
                response_text += token
            yield response_text
    except Exception as e:
        yield str(e)


# Debugging print statement
print("Starting Gradio interface setup...")

try:
    # Create the Gradio interface
    demo = gr.Interface(
        fn=respond,
        inputs=[
            gr.Textbox(label="Message"),
            gr.Image(label="Upload Medical Image (Optional)", type="pil"),
        ],
        outputs=gr.Textbox(label="Response", placeholder="Model response will appear here..."),
        title="LLAVA Model - Medical Image and Question",
        description="Upload a medical image and ask a specific question about it to get a medical description.",
        additional_inputs=[
            gr.Textbox(label="System message", value="You are a friendly Chatbot."),
            gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
            gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
            gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
        ],
    )

    # Launch the Gradio interface
    if __name__ == "__main__":
        print("Launching Gradio interface...")
        demo.launch()
except Exception as e:
    print(f"Error during Gradio setup: {str(e)}")