import gradio as gr
import torch
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Initialize FastAPI
app = FastAPI()

# Load the model and tokenizer at application startup
model_name = "Qwen/Qwen2.5-0.5B"
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",
        attn_implementation="eager"  # avoid the sdpa warning
    )
    print("Model and tokenizer loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    raise

# Request body for the API
class TextInput(BaseModel):
    prompt: str
    max_length: int = 100

# Text generation function (shared by the API and the Gradio UI)
def generate_text(prompt, max_length=100):
    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.inference_mode():
            outputs = model.generate(
                **inputs,  # pass input_ids and attention_mask
                max_length=int(max_length),  # the Gradio slider sends a float
                num_return_sequences=1,
                no_repeat_ngram_size=2,
                do_sample=True,
                top_k=50,
                top_p=0.95
            )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        raise RuntimeError(f"Error: {e}") from e

# API endpoint for text generation
@app.post("/generate")
async def generate_text_api(payload: TextInput):
    try:
        result = generate_text(payload.prompt, payload.max_length)
        return {"generated_text": result}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

# Health-check endpoint
@app.get("/")
async def root():
    return {"message": "Qwen2.5-0.5B API is running!"}

# Endpoint that lists the available URLs
@app.get("/api_link")
async def get_api_link(request: Request):
    scheme = request.url.scheme
    host = request.url.hostname
    if request.url.port:
        base_url = f"{scheme}://{host}:{request.url.port}"
    else:
        base_url = f"{scheme}://{host}"
    return {
        "api_url": base_url,
        "endpoints": {
            "health_check": f"{base_url}/",
            "generate_text": f"{base_url}/generate",
            "api_link": f"{base_url}/api_link",
            "interface": f"{base_url}/interface"
        }
    }

# Build the Gradio interface
def create_gradio_interface():
    with gr.Blocks(title="Qwen2.5-0.5B Text Generator") as demo:
        gr.Markdown("# Qwen2.5-0.5B Text Generator")
        gr.Markdown("Enter a prompt and get generated text!")
        with gr.Row():
            prompt_input = gr.Textbox(label="Prompt", placeholder="Type something...")
            max_length_input = gr.Slider(50, 500, value=100, step=10, label="Max Length")
        generate_button = gr.Button("Generate")
        output_text = gr.Textbox(label="Generated Text", interactive=False)
        generate_button.click(
            fn=generate_text,
            inputs=[prompt_input, max_length_input],
            outputs=output_text
        )
    return demo

# Mount the Gradio UI on the FastAPI app at /interface
app = gr.mount_gradio_app(app, create_gradio_interface(), path="/interface")

# Run the app when not on Hugging Face Spaces
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
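
# --- Example client call (a minimal sketch, not part of the app above) ---
# Assumes the server is running locally on the port used in the uvicorn call
# (http://localhost:7860) and that the third-party `requests` package is
# installed; any HTTP client works the same way against POST /generate.
#
# import requests
#
# response = requests.post(
#     "http://localhost:7860/generate",
#     json={"prompt": "Once upon a time", "max_length": 120},
#     timeout=120,
# )
# response.raise_for_status()
# print(response.json()["generated_text"])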