# NOTE(review): removed non-Python extraction artifacts that preceded the code
# (a "File size" banner, a git-blame hash gutter, and a line-number gutter).
# They were not part of the program and would be syntax errors in this file.
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Load the model and tokenizer once, at application start-up.
model_name = "Qwen/Qwen2.5-0.5B"
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",
        # "eager" avoids the sdpa attention warning on some setups.
        attn_implementation="eager",
    )
except Exception as e:
    print(f"Error loading model: {e}")
    raise
else:
    # Only reached when both loads succeeded.
    print("Model and tokenizer loaded successfully!")
# Text-generation function (shared by the UI button and the API endpoint).
def generate_text(prompt, max_length=100):
    """Generate a continuation of *prompt* with the loaded Qwen model.

    Args:
        prompt: Input text to continue.
        max_length: Total token budget (prompt + completion). Gradio's
            Slider/Number components may deliver this as a float, so it
            is coerced to ``int`` before being passed to ``generate``.

    Returns:
        The decoded generation (prompt included), or an ``"Error: ..."``
        string on failure — the UI shows that instead of crashing.
    """
    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        outputs = model.generate(
            inputs["input_ids"],
            # Pass the attention mask explicitly: silences the
            # "attention mask is not set" warning and gives correct
            # results when padding is involved.
            attention_mask=inputs["attention_mask"],
            max_length=int(max_length),
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            # Qwen has no dedicated pad token; fall back to EOS to avoid
            # the open-ended-generation warning.
            pad_token_id=tokenizer.eos_token_id,
        )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        return f"Error: {str(e)}"
# Helper that renders API usage information for the UI.
def get_api_info():
    """Return a human-readable description of the HTTP API endpoints."""
    base_url = "https://<your-space-name>.hf.space"
    lines = [
        "Welcome to Qwen2.5-0.5B API!",
        f"API Base URL: {base_url} (Replace '<your-space-name>' with your actual Space name)",
        "Endpoints:",
        f"- GET {base_url}/api/health_check (Check API status)",
        f"- POST {base_url}/api/generate (Generate text)",
        "To use the generate API, send a POST request with JSON:",
        '{"0": "your prompt", "1": 150}',
    ]
    return "\n".join(lines)
# Health-check handler (exposed to the API).
def health_check():
    """Report that the service is up and serving requests."""
    return "Qwen2.5-0.5B API is running!"
# Build the interactive Gradio front-end.
with gr.Blocks(title="Qwen2.5-0.5B Text Generator") as blocks_ui:
    gr.Markdown("# Qwen2.5-0.5B Text Generator")
    gr.Markdown("Enter a prompt below or use the API!")
    # API usage information, shown read-only.
    gr.Markdown("### API Information")
    api_info = gr.Textbox(label="API Details", value=get_api_info(), interactive=False)
    # Interactive text-generation controls.
    gr.Markdown("### Generate Text")
    with gr.Row():
        prompt_input = gr.Textbox(label="Prompt", placeholder="Type something...")
        max_length_input = gr.Slider(50, 500, value=100, step=10, label="Max Length")
    generate_button = gr.Button("Generate")
    output_text = gr.Textbox(label="Generated Text", interactive=False)
    # Wire the button to the shared generation function.
    generate_button.click(
        fn=generate_text,
        inputs=[prompt_input, max_length_input],
        outputs=output_text,
    )

# API-facing interfaces, served under /api/<name>.
interface = gr.Interface(
    fn=generate_text,
    inputs=["text", "number"],
    outputs="text",
    title="Qwen2.5-0.5B API",
    api_name="/generate",  # API endpoint: /api/generate
).queue()
health_interface = gr.Interface(
    fn=health_check,
    inputs=None,
    outputs="text",
    api_name="/health_check",  # API endpoint: /api/health_check
)

# BUG FIX: the original reassigned `demo = gr.TabbedInterface(...)`,
# silently discarding the Blocks UI built above — its button wiring was
# never served. Keep that UI as the first tab instead.
demo = gr.TabbedInterface(
    [blocks_ui, interface, health_interface],
    ["Playground", "Generate Text", "Health Check"],
)

# Launch the combined app (0.0.0.0:7860 is the Hugging Face Spaces default).
demo.launch(server_name="0.0.0.0", server_port=7860)