# Hugging Face Space (status: Sleeping)
import base64
import io

import torch
from flask import Flask, request, jsonify
from PIL import Image
from transformers import AutoProcessor, CLIPModel
# Hugging Face Hub id of the FashionCLIP checkpoint. The model and its
# processor are loaded from the same repository so they always match.
MODEL_ID = "patrickjohncyh/fashion-clip"

# Load the CLIP model and its processor once at import time; every request
# then reuses these module-level objects.
model = CLIPModel.from_pretrained(MODEL_ID)
processor = AutoProcessor.from_pretrained(MODEL_ID)

# Flask application object that the routes below register on.
app = Flask(__name__)
# Fonction pour la classification d'image avec du texte en entrée | |
def classify_image_with_text(text, image): | |
keywords = text.split(',') | |
image = decode_image_from_base64(image) | |
inputs = processor(text=keywords, images=image, return_tensors="pt", padding=True) | |
outputs = model(**inputs) | |
logits_per_image = outputs.logits_per_image # score de similarité image-texte | |
probs = logits_per_image.softmax(dim=1) | |
predicted_class_index = probs.argmax(dim=1).item() | |
predicted_label = keywords[predicted_class_index] | |
return predicted_label | |
# Fonction pour la classification d'image avec des propriétés et options | |
def classify_image_with_properties(properties, image): | |
image = decode_image_from_base64(image) | |
result = [] | |
for prop in properties: | |
property_name = prop['property'] | |
options = prop['options'] | |
keywords = options.split(',') | |
# Effectuer la classification pour chaque ensemble propriété-options | |
inputs = processor(text=keywords, images=image, return_tensors="pt", padding=True) | |
outputs = model(**inputs) | |
logits_per_image = outputs.logits_per_image | |
probs = logits_per_image.softmax(dim=1) | |
predicted_class_index = probs.argmax(dim=1).item() | |
# Obtenir l'option complète correspondant à l'indice prédit | |
predicted_label = keywords[predicted_class_index] | |
result.append({ "property": property_name, "value": predicted_label }) | |
return result | |
# Fonction pour décoder une image encodée en base64 en objet PIL.Image.Image | |
def decode_image_from_base64(image_data): | |
image_data = base64.b64decode(image_data) | |
image = Image.open(io.BytesIO(image_data)) | |
return image | |
# NOTE(review): the reviewed source had no route decorator here, so this view
# was never registered. "/" is the natural landing path; confirm with clients.
@app.route("/")
def root():
    """Return a plain-text welcome message for the API's landing page."""
    return "Welcome to the Fashion Clip API!"
# NOTE(review): the reviewed source had no route decorator here, so this view
# was never registered. The path mirrors the function name; confirm it matches
# what clients call.
@app.route("/classify", methods=["POST"])
def classify():
    """REST endpoint: classify a base64 image against comma-separated labels.

    Expects a JSON body ``{"text": "label1, label2, ...", "image": "<base64>"}``
    and responds with ``{"result": "<best label>"}``. Malformed requests get a
    400 response with ``{"error": "..."}`` instead of an unhandled server error.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'request body must be a JSON object'}), 400
    text = data.get('text')
    image = data.get('image')
    if not isinstance(text, str) or not isinstance(image, str):
        return jsonify({'error': "'text' and 'image' must both be strings"}), 400
    try:
        result = classify_image_with_text(text, image)
    except (ValueError, OSError) as err:
        # Bad base64 (binascii.Error is a ValueError), unreadable image bytes
        # (PIL raises OSError subclasses), or no usable label: client errors.
        return jsonify({'error': str(err)}), 400
    return jsonify({'result': result})
# NOTE(review): the reviewed source had no route decorator here, so this view
# was never registered. The path mirrors the function name; confirm it matches
# what clients call.
@app.route("/classify_properties", methods=["POST"])
def classify_properties():
    """REST endpoint: classify a base64 image against several option sets.

    Expects a JSON body::

        {"image": "<base64>",
         "properties": [{"property": "color", "options": "red, blue"}, ...]}

    and responds with ``{"result": [{"property": ..., "value": ...}, ...]}``.
    Malformed requests get a 400 response with ``{"error": "..."}``.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'request body must be a JSON object'}), 400
    properties = data.get('properties')
    image = data.get('image')
    if not isinstance(properties, list) or not isinstance(image, str):
        return jsonify({'error': "'properties' must be a list and 'image' a string"}), 400
    for prop in properties:
        # Validate every entry up front so we fail before running the model.
        if not (isinstance(prop, dict) and 'property' in prop
                and isinstance(prop.get('options'), str)):
            return jsonify({'error': "each property needs a 'property' key "
                                     "and an 'options' string"}), 400
    try:
        result = classify_image_with_properties(properties, image)
    except (ValueError, OSError) as err:
        # Bad base64, unreadable image bytes, or a property with no options.
        return jsonify({'error': str(err)}), 400
    return jsonify({'result': result})
if __name__ == "__main__":
    # Serve on every network interface; 7860 is presumably the port the
    # hosting Space exposes. Confirm before changing it.
    app.run(port=7860, host="0.0.0.0")