from flask import Flask, request, Response
from threading import Thread
import time
import torch
import base64
import io
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration
from transformers import TextIteratorStreamer
import spaces

app = Flask(__name__)

model_id = "xtuner/llava-llama-3-8b-v1_1-transformers"

processor = AutoProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
model.to("cuda:0")
# 128009 is the Llama 3 <|eot_id|> token; use it as the end-of-turn marker.
model.generation_config.eos_token_id = 128009


def decode_image_from_base64(image_data):
    """Decode a base64-encoded image into a PIL.Image.Image object."""
    image_data = base64.b64decode(image_data)
    image = Image.open(io.BytesIO(image_data))
    return image


@spaces.GPU
def bot_streaming(text, image):
    # Llama 3 chat template with the <image> placeholder expected by the
    # LLaVA processor.
    prompt = (
        f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{text}<|eot_id|>"
        f"<|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    image = decode_image_from_base64(image)
    inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda:0", torch.float16)

    streamer = TextIteratorStreamer(processor, skip_special_tokens=False, skip_prompt=True)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024, do_sample=False)

    # Run generation in a background thread so tokens can be streamed as they
    # are produced.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    time.sleep(0.5)  # give generation a head start before draining the streamer
    for new_text in streamer:
        # Stop at the end-of-turn token and drop everything after it.
        if "<|eot_id|>" in new_text:
            new_text = new_text.split("<|eot_id|>")[0]
        # Yield only the new fragment; yielding a cumulative buffer would
        # duplicate already-sent text in the HTTP stream.
        yield new_text


@app.route("/bot_streaming", methods=["POST"])
def handle_bot_streaming():
    data = request.json
    prompt = data.get("prompt")
    image = data.get("image")
    return Response(bot_streaming(prompt, image), mimetype="text/plain")


@app.get("/")
def root():
    return "Welcome to the llava-llama-3-8b API"


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)
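
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not executed by the server). A minimal
# client for the /bot_streaming endpoint, assuming the server is reachable at
# http://localhost:7860 and that "cat.png" is a hypothetical image file on
# disk; the endpoint consumes a base64-encoded image and streams plain text.
#
#   import base64
#   import requests
#
#   with open("cat.png", "rb") as f:
#       image_b64 = base64.b64encode(f.read()).decode("utf-8")
#
#   resp = requests.post(
#       "http://localhost:7860/bot_streaming",
#       json={"prompt": "What is in this image?", "image": image_b64},
#       stream=True,
#   )
#   for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
#       print(chunk, end="", flush=True)
# ---------------------------------------------------------------------------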