File size: 3,303 Bytes
9a76509
45d388a
 
 
 
 
59cd1f7
 
 
 
 
 
 
 
 
 
 
 
45d388a
ac716aa
9a76509
45d388a
9a76509
45d388a
 
9a76509
45d388a
 
 
 
 
 
9a76509
 
ac716aa
45d388a
ac716aa
 
1f5a613
ac716aa
 
1f5a613
ac716aa
 
 
 
1f5a613
ac716aa
59cd1f7
ac716aa
 
 
9f3a452
9a76509
ac716aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a73b48
ac716aa
 
 
 
 
 
 
 
 
 
 
 
 
9a76509
8a73b48
 
9a76509
ac716aa
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load the model and tokenizer once at application startup so every
# request reuses the same in-memory weights.
model_name = "Qwen/Qwen2.5-0.5B"
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",
        attn_implementation="eager"  # eager attention avoids the sdpa warning
    )
    print("Model and tokenizer loaded successfully!")
except Exception as e:
    # A missing model is fatal for the whole app: log, then re-raise so the
    # process exits instead of serving a broken UI.
    print(f"Error loading model: {e}")
    raise

# Text-generation helper shared by both the UI and the API endpoints.
def generate_text(prompt, max_length=100):
    """Generate a continuation of *prompt* with the loaded causal LM.

    Parameters
    ----------
    prompt : str
        Input text to continue.
    max_length : int or float, optional
        Maximum total token count (prompt + generation).  The Gradio
        slider delivers a float, so it is coerced to ``int`` here.

    Returns
    -------
    str
        The decoded generation, or an ``"Error: ..."`` string on failure
        (this best-effort contract is relied on by the UI/API callers).
    """
    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        outputs = model.generate(
            # BUG FIX: pass the full tokenizer output so the attention_mask
            # reaches generate(); the original forwarded only input_ids,
            # which triggers a warning and can skew sampling.
            **inputs,
            max_length=int(max_length),
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            # Pin pad_token_id to eos to silence the "pad token not set"
            # warning Qwen emits on every call.
            pad_token_id=tokenizer.eos_token_id,
        )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # Surface the failure as a string rather than raising, because the
        # return value is shown directly in the UI / API response.
        return f"Error: {str(e)}"

# Builds the API usage notes shown in the UI's "API Details" box.
def get_api_info():
    """Return a human-readable, newline-separated description of the API."""
    base_url = "https://<your-space-name>.hf.space"
    details = [
        "Welcome to Qwen2.5-0.5B API!",
        f"API Base URL: {base_url} (Replace '<your-space-name>' with your actual Space name)",
        "Endpoints:",
        f"- GET {base_url}/api/health_check (Check API status)",
        f"- POST {base_url}/api/generate (Generate text)",
        "To use the generate API, send a POST request with JSON:",
        '{"0": "your prompt", "1": 150}',
    ]
    return "\n".join(details)

# Liveness probe backing the /api/health_check endpoint.
def health_check():
    """Report that the Qwen2.5-0.5B service is up and serving requests."""
    return "Qwen2.5-0.5B API is running!"

# Build the Gradio Blocks user interface.
with gr.Blocks(title="Qwen2.5-0.5B Text Generator") as demo:
    gr.Markdown("# Qwen2.5-0.5B Text Generator")
    gr.Markdown("Enter a prompt below or use the API!")
    
    # Read-only box showing how to call the HTTP API.
    gr.Markdown("### API Information")
    api_info = gr.Textbox(label="API Details", value=get_api_info(), interactive=False)
    
    # Text-generation controls: prompt plus a max-length slider (50-500, step 10).
    gr.Markdown("### Generate Text")
    with gr.Row():
        prompt_input = gr.Textbox(label="Prompt", placeholder="Type something...")
        max_length_input = gr.Slider(50, 500, value=100, step=10, label="Max Length")
    
    generate_button = gr.Button("Generate")
    output_text = gr.Textbox(label="Generated Text", interactive=False)
    
    # Wire the button to generate_text(prompt, max_length).
    generate_button.click(
        fn=generate_text,
        inputs=[prompt_input, max_length_input],
        outputs=output_text
    )

# Define API endpoints with Gradio Interface wrappers.
interface = gr.Interface(
    fn=generate_text,
    inputs=["text", "number"],
    outputs="text",
    title="Qwen2.5-0.5B API",
    api_name="/generate"  # API endpoint: /api/generate
).queue()

health_interface = gr.Interface(
    fn=health_check,
    inputs=None,
    outputs="text",
    api_name="/health_check"  # API endpoint: /api/health_check
)

# Mount everything in one tabbed app.  BUG FIX: the original rebound `demo`
# here, silently discarding the Blocks UI built above (its button wiring and
# slider were never served).  Bind the tabbed app to a new name and include
# the Blocks UI as the first tab so both the custom UI and the API
# interfaces are reachable.
app = gr.TabbedInterface(
    [demo, interface, health_interface],
    ["Demo", "Generate Text", "Health Check"],
)

# Run the application; 0.0.0.0:7860 is what the HF Space container exposes.
app.launch(server_name="0.0.0.0", server_port=7860)