AIAPI / app.py
rapacious's picture
Update app.py
8a73b48 verified
raw
history blame
3.3 kB
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Load the tokenizer and model once, when the application starts, so that
# every request reuses the same module-level instances.
model_name = "Qwen/Qwen2.5-0.5B"
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",
        attn_implementation="eager",  # avoids the sdpa warning
    )
except Exception as load_error:
    # Report the failure on stdout, then re-raise so start-up is aborted.
    print(f"Error loading model: {load_error}")
    raise
else:
    print("Model and tokenizer loaded successfully!")
# Text-generation function (shared by the UI and the API)
def generate_text(prompt, max_length=100):
    """Generate a sampled continuation of ``prompt`` with the loaded model.

    Args:
        prompt: Text to continue. Must be non-empty.
        max_length: Forwarded to ``model.generate`` as ``max_length``. It is
            coerced with ``int()`` because callers may supply a float or a
            numeric string (e.g. values coming from a numeric input widget).

    Returns:
        The first generated sequence decoded with special tokens skipped, or
        a string of the form ``"Error: <message>"`` if the input is rejected
        or any step raises.
    """
    try:
        # Validate the cheap inputs before touching the model.
        if not prompt:
            raise ValueError("prompt must not be empty")
        length_limit = int(max_length)

        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        # Unpack the whole encoding so the attention mask produced by the
        # tokenizer is forwarded together with the input ids.
        outputs = model.generate(
            **inputs,
            max_length=length_limit,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            do_sample=True,
            top_k=50,
            top_p=0.95,
        )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # Errors are returned as text (not raised) so the caller always
        # receives a string result.
        return f"Error: {e}"
# Function that builds the API information text
def get_api_info():
    """Return the multi-line help text describing the API.

    The text contains a placeholder base URL, the two advertised endpoints
    and an example JSON request body, joined with newlines.

    NOTE(review): the endpoint paths and the JSON payload shape are
    reproduced as written; confirm them against the routes the installed
    Gradio version actually exposes.
    """
    host = "https://<your-space-name>.hf.space"
    lines = [
        "Welcome to Qwen2.5-0.5B API!",
        f"API Base URL: {host} (Replace '<your-space-name>' with your actual Space name)",
        "Endpoints:",
        f"- GET {host}/api/health_check (Check API status)",
        f"- POST {host}/api/generate (Generate text)",
        "To use the generate API, send a POST request with JSON:",
        '{"0": "your prompt", "1": 150}',
    ]
    return "\n".join(lines)
# Health-check function (for the API)
def health_check():
    """Return a fixed status message indicating the API is up."""
    status_message = "Qwen2.5-0.5B API is running!"
    return status_message
# Build the Gradio user interface.
# The hand-built layout is bound to its own name (not ``demo``) so that it is
# not discarded when ``demo`` is assigned further down; it is served as the
# first tab of the final app.
with gr.Blocks(title="Qwen2.5-0.5B Text Generator") as custom_ui:
    gr.Markdown("# Qwen2.5-0.5B Text Generator")
    gr.Markdown("Enter a prompt below or use the API!")

    # Read-only panel showing the API information.
    gr.Markdown("### API Information")
    api_info = gr.Textbox(label="API Details", value=get_api_info(), interactive=False)

    # Text-generation controls.
    # NOTE(review): the row is assumed to hold only the prompt and the slider,
    # with the button and output below it — confirm the intended layout.
    gr.Markdown("### Generate Text")
    with gr.Row():
        prompt_input = gr.Textbox(label="Prompt", placeholder="Type something...")
        max_length_input = gr.Slider(50, 500, value=100, step=10, label="Max Length")
    generate_button = gr.Button("Generate")
    output_text = gr.Textbox(label="Generated Text", interactive=False)

    # Wire the button to generate_text. api_name=False keeps this UI-only
    # event out of the API, so the named endpoints are still only the two
    # interfaces defined below.
    generate_button.click(
        fn=generate_text,
        inputs=[prompt_input, max_length_input],
        outputs=output_text,
        api_name=False,
    )

# Define the API endpoints with Gradio interfaces.
# NOTE(review): both api_name values carry a leading slash — confirm against
# the installed Gradio version that this produces the intended routes.
interface = gr.Interface(
    fn=generate_text,
    inputs=["text", "number"],
    outputs="text",
    title="Qwen2.5-0.5B API",
    api_name="/generate",
).queue()
health_interface = gr.Interface(
    fn=health_check,
    inputs=None,
    outputs="text",
    api_name="/health_check",
)

# Combine the custom UI and both interfaces into the app that is launched.
demo = gr.TabbedInterface(
    [custom_ui, interface, health_interface],
    ["Text Generator", "Generate Text", "Health Check"],
)

# Start the application, listening on all network interfaces on port 7860.
demo.launch(server_name="0.0.0.0", server_port=7860)