File size: 4,142 Bytes
9a76509 45d388a 59cd1f7 9a46859 59cd1f7 45d388a ac716aa 9a46859 45d388a 9a46859 45d388a 9a46859 9a76509 45d388a 9a46859 45d388a 9a46859 9a76509 9a46859 45d388a ac716aa 1f5a613 ac716aa 1f5a613 ac716aa 1f5a613 ac716aa 59cd1f7 ac716aa 9f3a452 9a76509 ac716aa 9a46859 ac716aa 9a46859 ac716aa 9a46859 ac716aa 8a73b48 9a46859 ac716aa 9a46859 ac716aa 9a46859 ac716aa 9a76509 8a73b48 9a76509 ac716aa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Tải model và tokenizer khi ứng dụng khởi động
model_name = "Qwen/Qwen2.5-0.5B"
try:
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Đặt pad_token_id nếu chưa có
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype="auto",
device_map="auto",
attn_implementation="eager" # Tránh cảnh báo sdpa
)
print("Model and tokenizer loaded successfully!")
except Exception as e:
print(f"Error loading model: {e}")
raise
# Hàm sinh văn bản (dùng cho cả UI và API)
def generate_text(prompt, max_length, state):
try:
# Mã hóa đầu vào với attention_mask
inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(model.device)
outputs = model.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
max_length=max_length,
num_return_sequences=1,
no_repeat_ngram_size=2,
do_sample=True,
top_k=50,
top_p=0.95,
pad_token_id=tokenizer.pad_token_id
)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Cập nhật state với kết quả mới
state.append(generated_text)
return state, generated_text # Trả về state và output để hiển thị
except Exception as e:
error_msg = f"Error: {str(e)}"
state.append(error_msg)
return state, error_msg
# Hàm hiển thị thông tin API
def get_api_info():
base_url = "https://<your-space-name>.hf.space"
return (
"Welcome to Qwen2.5-0.5B API!\n"
f"API Base URL: {base_url} (Replace '<your-space-name>' with your actual Space name)\n"
"Endpoints:\n"
f"- GET {base_url}/api/health_check (Check API status)\n"
f"- POST {base_url}/api/generate (Generate text)\n"
"To use the generate API, send a POST request with JSON:\n"
'{"0": "your prompt", "1": 150}'
)
# Hàm kiểm tra sức khỏe (dành cho API)
def health_check():
return "Qwen2.5-0.5B API is running!"
# Tạo giao diện Gradio
with gr.Blocks(title="Qwen2.5-0.5B Text Generator") as demo:
gr.Markdown("# Qwen2.5-0.5B Text Generator")
gr.Markdown("Enter a prompt below or use the API!")
# State để lưu trữ lịch sử kết quả
state = gr.State(value=[]) # Khởi tạo state là danh sách rỗng
# Hiển thị thông tin API
gr.Markdown("### API Information")
api_info = gr.Textbox(label="API Details", value=get_api_info(), interactive=False)
# Giao diện sinh văn bản
gr.Markdown("### Generate Text")
with gr.Row():
prompt_input = gr.Textbox(label="Prompt", placeholder="Type something...")
max_length_input = gr.Slider(50, 500, value=100, step=10, label="Max Length")
generate_button = gr.Button("Generate")
output_text = gr.Textbox(label="Generated Text History", interactive=False, lines=10)
# Liên kết button với hàm generate_text
generate_button.click(
fn=generate_text,
inputs=[prompt_input, max_length_input, state],
outputs=[state, output_text] # Cập nhật cả state và output_text
)
# Định nghĩa API endpoints với Gradio
interface = gr.Interface(
fn=lambda prompt, max_length: generate_text(prompt, max_length, [])[1], # Chỉ lấy output, không dùng state cho API
inputs=["text", "number"],
outputs="text",
title="Qwen2.5-0.5B API",
api_name="/generate"
).queue()
health_interface = gr.Interface(
fn=health_check,
inputs=None,
outputs="text",
api_name="/health_check"
)
# Gắn các interface vào demo
demo = gr.TabbedInterface([interface, health_interface], ["Generate Text", "Health Check"])
# Chạy ứng dụng
demo.launch(server_name="0.0.0.0", server_port=7860) |