import gradio as gr from transformers import AutoModelForCausalLM, AutoTokenizer import torch # Tải model và tokenizer khi ứng dụng khởi động model_name = "Qwen/Qwen2.5-0.5B" try: tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype="auto", device_map="auto", attn_implementation="eager" # Tránh cảnh báo sdpa ) print("Model and tokenizer loaded successfully!") except Exception as e: print(f"Error loading model: {e}") raise # Hàm sinh văn bản (dùng cho cả UI và API) def generate_text(prompt, max_length=100): try: inputs = tokenizer(prompt, return_tensors="pt").to(model.device) outputs = model.generate( inputs["input_ids"], max_length=max_length, num_return_sequences=1, no_repeat_ngram_size=2, do_sample=True, top_k=50, top_p=0.95 ) return tokenizer.decode(outputs[0], skip_special_tokens=True) except Exception as e: return f"Error: {str(e)}" # Hàm hiển thị thông tin API def get_api_info(): # Trên Hugging Face Spaces, API URL sẽ dựa trên tên Space # Khi chạy local, ta giả định port 7860 base_url = "http://localhost:7860" if gr.context.local else "https://.hf.space" return ( "Welcome to Qwen2.5-0.5B API!\n" f"API Base URL: {base_url}\n" "Endpoints:\n" f"- GET {base_url}/api/health_check (Check API status)\n" f"- POST {base_url}/api/generate (Generate text)\n" "To use the generate API, send a POST request with JSON:\n" '{"prompt": "your prompt", "max_length": 150}' ) # Hàm kiểm tra sức khỏe (dành cho API) def health_check(): return "Qwen2.5-0.5B API is running!" # Tạo giao diện Gradio with gr.Blocks(title="Qwen2.5-0.5B Text Generator") as demo: gr.Markdown("# Qwen2.5-0.5B Text Generator") gr.Markdown("Enter a prompt below or use the API!") # Hiển thị thông tin API gr.Markdown("### API Information") api_info = gr.Textbox(label="API Details", value=get_api_info(), interactive=False) # Giao diện sinh văn bản gr.Markdown("### Generate Text") with gr.Row(): prompt_input = gr.Textbox(label="Prompt", placeholder="Type something...") max_length_input = gr.Slider(50, 500, value=100, step=10, label="Max Length") generate_button = gr.Button("Generate") output_text = gr.Textbox(label="Generated Text", interactive=False) # Liên kết button với hàm generate_text generate_button.click( fn=generate_text, inputs=[prompt_input, max_length_input], outputs=output_text ) # Định nghĩa API endpoints với Gradio demo = gr.Interface( fn=generate_text, inputs=["text", "number"], outputs="text", title="Qwen2.5-0.5B API", api_name="/generate" # API endpoint: /api/generate ).queue() # Thêm endpoint health check health_interface = gr.Interface( fn=health_check, inputs=None, outputs="text", api_name="/health_check" # API endpoint: /api/health_check ) # Kết hợp giao diện và API app = gr.mount_gradio_app(demo, health_interface) # Chạy ứng dụng demo.launch(server_name="0.0.0.0", server_port=7860)