# Page header captured with this file: author Saad0KH, commit 0c4b5a5
# ("Create app.py", verified), file size 2.62 kB.
from flask import Flask, request, jsonify
from PIL import Image
import base64
import io
import time
from threading import Thread
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
from transformers import TextIteratorStreamer
# Checkpoint identifier handed to both the processor and the model loaders.
model_id = "xtuner/llava-llama-3-8b-v1_1-transformers"

# The processor builds the model inputs from the prompt text and the image.
processor = AutoProcessor.from_pretrained(model_id)

# Load the weights in half precision with reduced CPU memory use, then place
# the model on the first CUDA device.
_load_options = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
model = LlavaForConditionalGeneration.from_pretrained(model_id, **_load_options)
model.to("cuda:0")

# Stop generation at token id 128009.
# NOTE(review): presumably the id of "<|eot_id|>" in this tokenizer — confirm.
model.generation_config.eos_token_id = 128009
# NOTE(review): the original `@spaces.GPU` decorator was dropped because this
# file never imports `spaces`, so evaluating the decorator raised NameError as
# soon as the module was loaded. Re-add it together with `import spaces` if the
# deployment actually requires it.
def bot_streaming(text, image):
    """Stream the model's answer to ``text`` about a base64-encoded ``image``.

    Generation runs on a background thread that feeds a
    ``TextIteratorStreamer``. Each time a new piece of text arrives, this
    generator yields the answer accumulated so far, so the last yielded value
    is the complete answer.

    Args:
        text: The user's question or instruction about the image.
        image: Base64-encoded image data, as accepted by
            ``decode_image_from_base64``.

    Yields:
        str: The text generated so far, with any ``<|eot_id|>`` marker and
        whatever follows it in the same chunk removed.
    """
    # decode_image_from_base64 already returns a PIL image, so it must not be
    # passed through Image.open() a second time.
    pil_image = decode_image_from_base64(image)
    prompt = f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"

    # Keyword arguments keep this call independent of the processor's
    # positional parameter order. The inputs go to device 0, where the model
    # was placed at import time.
    inputs = processor(text=prompt, images=pil_image, return_tensors="pt").to(0, torch.float16)

    # skip_prompt=True: only newly generated text is streamed back.
    # skip_special_tokens=False: "<|eot_id|>" may appear and is stripped below.
    streamer = TextIteratorStreamer(processor, skip_special_tokens=False, skip_prompt=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024, do_sample=False)

    # generate() blocks until it is done, so it runs on its own thread while
    # this generator consumes the streamer.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        # Drop the end-of-turn marker (and anything after it in this chunk).
        if "<|eot_id|>" in new_text:
            new_text = new_text.split("<|eot_id|>")[0]
        buffer += new_text
        yield buffer
# Create the Flask application instance that the routes below are registered on.
app = Flask(__name__)
def decode_image_from_base64(image_data):
    """Decode a base64-encoded image into a ``PIL.Image.Image``.

    Args:
        image_data: The image bytes encoded as base64 (``str`` or ``bytes``).
            A string in data-URI form (``data:<mime>;base64,<payload>``) is
            also accepted; only the payload after the first comma is decoded.

    Returns:
        The image opened by ``PIL.Image.open`` from the decoded bytes.
    """
    if isinstance(image_data, str) and image_data.startswith("data:"):
        # Plain base64 never contains ":" or ",", so stripping a data-URI
        # header cannot affect input that was already valid.
        _, _, image_data = image_data.partition(",")
    raw_bytes = base64.b64decode(image_data)
    return Image.open(io.BytesIO(raw_bytes))
@app.route("/", methods=["GET"])
def root():
    """Return the plain-text greeting served at the site root."""
    greeting = "Welcome to the Llava-extra API!"
    return greeting
# REST API route
@app.route('/api/classify', methods=['POST'])
def classify():
    """Answer a prompt about an image and return the generated text as JSON.

    Expects a JSON object body with two fields:
        prompt: The text passed to ``bot_streaming`` as the question.
        image: The base64-encoded image passed to ``bot_streaming``.

    Returns:
        ``{"out": <generated text>}`` on success, or ``{"error": ...}`` with
        HTTP status 400 when the body is not a JSON object or a field is
        missing.
    """
    # silent=True yields None instead of raising when the body is not JSON.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'prompt' not in data or 'image' not in data:
        return jsonify({'error': "JSON body must contain 'prompt' and 'image'"}), 400

    # bot_streaming is a generator of progressively longer answers and cannot
    # be handed to jsonify directly; drain it and keep the final value.
    result = ""
    for result in bot_streaming(data['prompt'], data['image']):
        pass
    return jsonify({'out': result})
# Start Flask's built-in server only when this file is executed as a script.
if __name__ == "__main__":
    # "0.0.0.0" listens on every network interface; the port is fixed at 7860.
    app.run(port=7860, host="0.0.0.0")