fashion-clip / app.py
Saad0KH's picture
Update app.py
f8f1b61 verified
from flask import Flask, request, jsonify
from transformers import AutoProcessor, CLIPModel
from PIL import Image
import base64
import io
# Charger le modèle CLIP et le processeur
model = CLIPModel.from_pretrained("patrickjohncyh/fashion-clip")
processor = AutoProcessor.from_pretrained("patrickjohncyh/fashion-clip")
# Créer une instance Flask
app = Flask(__name__)
# Fonction pour la classification d'image avec du texte en entrée
def classify_image_with_text(text, image):
keywords = text.split(',')
image = decode_image_from_base64(image)
inputs = processor(text=keywords, images=image, return_tensors="pt", padding=True)
outputs = model(**inputs)
logits_per_image = outputs.logits_per_image # score de similarité image-texte
probs = logits_per_image.softmax(dim=1)
predicted_class_index = probs.argmax(dim=1).item()
predicted_label = keywords[predicted_class_index]
return predicted_label
# Fonction pour la classification d'image avec des propriétés et options
def classify_image_with_properties(properties, image):
image = decode_image_from_base64(image)
result = []
for prop in properties:
property_name = prop['property']
options = prop['options']
keywords = options.split(',')
# Effectuer la classification pour chaque ensemble propriété-options
inputs = processor(text=keywords, images=image, return_tensors="pt", padding=True)
outputs = model(**inputs)
logits_per_image = outputs.logits_per_image
probs = logits_per_image.softmax(dim=1)
predicted_class_index = probs.argmax(dim=1).item()
# Obtenir l'option complète correspondant à l'indice prédit
predicted_label = keywords[predicted_class_index]
result.append({ "property": property_name, "value": predicted_label })
return result
# Fonction pour décoder une image encodée en base64 en objet PIL.Image.Image
def decode_image_from_base64(image_data):
image_data = base64.b64decode(image_data)
image = Image.open(io.BytesIO(image_data))
return image
@app.get("/")
def root():
return "Welcome to the Fashion Clip API!"
# Route pour l'API REST de classification simple
@app.route('/api/classify', methods=['POST'])
def classify():
data = request.json
text = data['text']
image = data['image']
result = classify_image_with_text(text, image)
return jsonify({'result': result})
# Route pour l'API REST de classification avec propriétés et options
@app.route('/api/classify-properties', methods=['POST'])
def classify_properties():
data = request.json
properties = data['properties']
image = data['image']
result = classify_image_with_properties(properties, image)
return jsonify({'result': result})
if __name__ == "__main__":
app.run(host="0.0.0.0", port=7860)