import gradio as gr from transformers import AutoModelForCausalLM, AutoTokenizer import torch # Tải model và tokenizer khi ứng dụng khởi động model_name = "Qwen/Qwen2.5-0.5B" try: tokenizer = AutoTokenizer.from_pretrained(model_name) # Đặt pad_token_id nếu chưa có if tokenizer.pad_token_id is None: tokenizer.pad_token_id = tokenizer.eos_token_id model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype="auto", device_map="auto", attn_implementation="eager" # Tránh cảnh báo sdpa ) print("Model and tokenizer loaded successfully!") except Exception as e: print(f"Error loading model: {e}") raise # Hàm sinh văn bản (dùng cho cả UI và API) def generate_text(prompt, max_length, state): try: # Mã hóa đầu vào với attention_mask inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(model.device) outputs = model.generate( input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_length=max_length, num_return_sequences=1, no_repeat_ngram_size=2, do_sample=True, top_k=50, top_p=0.95, pad_token_id=tokenizer.pad_token_id ) generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) # Cập nhật state với kết quả mới state.append(generated_text) return state, generated_text # Trả về state và output để hiển thị except Exception as e: error_msg = f"Error: {str(e)}" state.append(error_msg) return state, error_msg # Hàm hiển thị thông tin API def get_api_info(): base_url = "https://.hf.space" return ( "Welcome to Qwen2.5-0.5B API!\n" f"API Base URL: {base_url} (Replace '' with your actual Space name)\n" "Endpoints:\n" f"- GET {base_url}/api/health_check (Check API status)\n" f"- POST {base_url}/api/generate (Generate text)\n" "To use the generate API, send a POST request with JSON:\n" '{"0": "your prompt", "1": 150}' ) # Hàm kiểm tra sức khỏe (dành cho API) def health_check(): return "Qwen2.5-0.5B API is running!" # Tạo giao diện Gradio with gr.Blocks(title="Qwen2.5-0.5B Text Generator") as demo: gr.Markdown("# Qwen2.5-0.5B Text Generator") gr.Markdown("Enter a prompt below or use the API!") # State để lưu trữ lịch sử kết quả state = gr.State(value=[]) # Khởi tạo state là danh sách rỗng # Hiển thị thông tin API gr.Markdown("### API Information") api_info = gr.Textbox(label="API Details", value=get_api_info(), interactive=False) # Giao diện sinh văn bản gr.Markdown("### Generate Text") with gr.Row(): prompt_input = gr.Textbox(label="Prompt", placeholder="Type something...") max_length_input = gr.Slider(50, 500, value=100, step=10, label="Max Length") generate_button = gr.Button("Generate") output_text = gr.Textbox(label="Generated Text History", interactive=False, lines=10) # Liên kết button với hàm generate_text generate_button.click( fn=generate_text, inputs=[prompt_input, max_length_input, state], outputs=[state, output_text] # Cập nhật cả state và output_text ) # Định nghĩa API endpoints với Gradio interface = gr.Interface( fn=lambda prompt, max_length: generate_text(prompt, max_length, [])[1], # Chỉ lấy output, không dùng state cho API inputs=["text", "number"], outputs="text", title="Qwen2.5-0.5B API", api_name="/generate" ).queue() health_interface = gr.Interface( fn=health_check, inputs=None, outputs="text", api_name="/health_check" ) # Gắn các interface vào demo demo = gr.TabbedInterface([interface, health_interface], ["Generate Text", "Health Check"]) # Chạy ứng dụng demo.launch(server_name="0.0.0.0", server_port=7860)