# Image_detec / app.py (Hugging Face Space by Exched)
# Last commit: "Update app.py" (5c78d10, verified)
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPProcessor, CLIPModel
from PIL import Image
import requests
from io import BytesIO
# Hugging Face Hub checkpoint ids used by this app.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
MISTRAL_CHECKPOINT = "mistralai/Mistral-7B-Instruct-v0.3"

# CLIP model + processor: used by classify_image for zero-shot style labels.
clip_processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)
clip_model = CLIPModel.from_pretrained(CLIP_CHECKPOINT)

# Mistral-7B-Instruct tokenizer + causal LM: used by chat_with_mistral.
# NOTE(review): no torch_dtype or device is passed, so the 7B weights load with
# the library's default dtype on CPU; confirm the host has enough RAM.
mistral_tokenizer = AutoTokenizer.from_pretrained(MISTRAL_CHECKPOINT)
mistral_model = AutoModelForCausalLM.from_pretrained(MISTRAL_CHECKPOINT)
def classify_image(input_image, labels=("anime", "cartoon", "realistic", "painting")):
    """Zero-shot classify an image's visual style with CLIP.

    Args:
        input_image: An image URL (``str``), which is downloaded, or an array
            accepted by ``PIL.Image.fromarray``. The Gradio input component
            uses ``type="numpy"``, so UI calls pass a numpy array. ``None``
            (e.g. when the Gradio image is cleared) yields ``""``.
        labels: Candidate text labels scored against the image. Defaults to
            the four style labels the app has always used.

    Returns:
        The label with the highest CLIP image-text probability, or ``""``
        when there is no image.

    Raises:
        ValueError: If ``labels`` is empty.
        requests.HTTPError: If downloading a URL returns an error status.
    """
    if input_image is None:
        # Clearing the image triggers the change event with no value;
        # Image.fromarray(None) would raise.
        return ""

    # One list serves as both the CLIP text prompts and the index -> label
    # lookup, so the two can never drift apart.
    labels = list(labels)
    if not labels:
        raise ValueError("labels must contain at least one candidate")

    if isinstance(input_image, str):
        # Bounded wait (seconds) so a dead URL cannot hang the worker; fail
        # loudly on HTTP errors instead of handing an error page to PIL.
        response = requests.get(input_image, timeout=30)
        response.raise_for_status()
        img = Image.open(BytesIO(response.content))
    else:
        img = Image.fromarray(input_image)

    inputs = clip_processor(text=labels, images=img, return_tensors="pt", padding=True)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = clip_model(**inputs)
    # logits_per_image holds image-text similarity scores, one row per image;
    # softmax over the label axis turns them into label probabilities.
    probs = outputs.logits_per_image.softmax(dim=1)
    return labels[int(probs[0].argmax())]
def chat_with_mistral(input_text, max_new_tokens=150):
    """Generate a single-turn reply from Mistral-7B-Instruct.

    Args:
        input_text: The user's message. Empty or whitespace-only input returns
            ``""`` without running the model.
        max_new_tokens: Upper bound on generated tokens, excluding the prompt.
            The previous ``max_length=150`` counted prompt tokens too, so long
            questions left little or no room for an answer.

    Returns:
        The decoded reply text only; the prompt is not echoed back.
    """
    if not (input_text and input_text.strip()):
        return ""

    # Instruct models expect their chat format; apply_chat_template builds it
    # from the tokenizer's own template instead of feeding raw text.
    # return_dict=True yields input_ids *and* attention_mask.
    inputs = mistral_tokenizer.apply_chat_template(
        [{"role": "user", "content": input_text}],
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to(mistral_model.device)

    outputs = mistral_model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        # Pass EOS as the pad id explicitly. Mistral's tokenizer presumably
        # defines no pad token, so generate() would otherwise warn.
        pad_token_id=mistral_tokenizer.eos_token_id,
    )

    # generate() returns prompt + continuation; decode only the new tokens.
    prompt_len = inputs["input_ids"].shape[-1]
    return mistral_tokenizer.decode(outputs[0, prompt_len:], skip_special_tokens=True).strip()
# Gradio UI with two tabs: a Mistral chat box and a CLIP style classifier.
with gr.Blocks() as demo:
    # Tab 1: text prompt -> Mistral reply, triggered on submit (Enter).
    with gr.Tab("Chat with Mistral"):
        chat_input = gr.Textbox(
            label="Ask Mistral 7B",
            placeholder="Type your question here...",
        )
        chat_output = gr.Textbox(
            label="Mistral's Reply",
            interactive=False,
        )
        chat_input.submit(
            fn=chat_with_mistral,
            inputs=chat_input,
            outputs=chat_output,
        )

    # Tab 2: image (as a numpy array) -> predicted label, re-run on every change.
    with gr.Tab("Classify Anime Image"):
        img_input = gr.Image(
            type="numpy",
            label="Upload Image for Anime Classification",
        )
        img_output = gr.Textbox(
            label="Predicted Label",
            interactive=False,
        )
        img_input.change(
            fn=classify_image,
            inputs=img_input,
            outputs=img_output,
        )

# Start the Gradio web server.
demo.launch()