# llava-llama-3-8b / app_api.py
from flask import Flask, request, Response
from threading import Thread
import time
import torch
import base64
import io
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration
from transformers import TextIteratorStreamer
import spaces

app = Flask(__name__)
model_id = "xtuner/llava-llama-3-8b-v1_1-transformers"

processor = AutoProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
model.to("cuda:0")
# 128009 is <|eot_id|>, Llama-3's end-of-turn token; stop generation there.
model.generation_config.eos_token_id = 128009
# Decode a base64-encoded image into a PIL.Image.Image object.
def decode_image_from_base64(image_data):
    image_data = base64.b64decode(image_data)
    image = Image.open(io.BytesIO(image_data))
    return image
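# (Clients build the payload with the inverse step, e.g.
# base64.b64encode(image_bytes).decode("utf-8"); a sketch of a full
# example client is at the end of this file.)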
@spaces.GPU
def bot_streaming(text, image):
    # Llava-Llama-3 prompt format: Llama-3 chat template with an <image> slot.
    prompt = (
        "<|start_header_id|>user<|end_header_id|>\n\n"
        f"<image>\n{text}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    image = decode_image_from_base64(image)
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(0, torch.float16)

    # Run generation in a background thread and stream tokens as they arrive.
    streamer = TextIteratorStreamer(processor, skip_special_tokens=False, skip_prompt=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024, do_sample=False)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    time.sleep(0.5)  # give generation a brief head start
    for new_text in streamer:
        # Truncate at the end-of-turn token rather than emitting it.
        if "<|eot_id|>" in new_text:
            new_text = new_text.split("<|eot_id|>")[0]
        time.sleep(0.06)
        # Yield only the new chunk: the HTTP response body is the
        # concatenation of everything yielded, so re-sending the whole
        # buffer each time would duplicate text on the client side.
        yield new_text
@app.route('/bot_streaming', methods=['POST'])
def handle_bot_streaming():
    data = request.json
    prompt = data.get("prompt")
    image = data.get("image")  # base64-encoded image bytes
    return Response(bot_streaming(prompt, image), mimetype='text/plain')
@app.get("/")
def root():
    return "Welcome to the llava-llama-3-8b API"
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)
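
# --- Example client: a minimal sketch, not part of the server. It assumes
# the server is reachable at http://localhost:7860 and that a local file
# named "photo.png" exists; both, like the use of the `requests` library,
# are illustrative choices, not requirements of the API.
#
# import base64
# import requests
#
# with open("photo.png", "rb") as f:
#     payload = {
#         "prompt": "Describe this image.",
#         "image": base64.b64encode(f.read()).decode("utf-8"),
#     }
#
# # stream=True prints chunks as the server yields them.
# with requests.post("http://localhost:7860/bot_streaming", json=payload, stream=True) as r:
#     r.raise_for_status()
#     for chunk in r.iter_content(chunk_size=None, decode_unicode=True):
#         print(chunk, end="", flush=True)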