File size: 4,142 Bytes
9a76509
45d388a
 
 
 
 
59cd1f7
 
9a46859
 
 
59cd1f7
 
 
 
 
 
 
 
 
 
45d388a
ac716aa
9a46859
45d388a
9a46859
 
45d388a
9a46859
 
9a76509
45d388a
 
 
 
9a46859
 
45d388a
9a46859
 
 
 
9a76509
9a46859
 
 
45d388a
ac716aa
 
1f5a613
ac716aa
 
1f5a613
ac716aa
 
 
 
1f5a613
ac716aa
59cd1f7
ac716aa
 
 
9f3a452
9a76509
ac716aa
 
 
 
9a46859
 
 
ac716aa
 
 
 
 
 
 
 
 
 
 
9a46859
ac716aa
 
 
 
9a46859
 
ac716aa
 
 
8a73b48
9a46859
ac716aa
 
 
9a46859
ac716aa
 
 
 
 
 
9a46859
ac716aa
9a76509
8a73b48
 
9a76509
ac716aa
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Tải model và tokenizer khi ứng dụng khởi động
model_name = "Qwen/Qwen2.5-0.5B"
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Đặt pad_token_id nếu chưa có
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",
        attn_implementation="eager"  # Tránh cảnh báo sdpa
    )
    print("Model and tokenizer loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    raise

# Hàm sinh văn bản (dùng cho cả UI và API)
def generate_text(prompt, max_length, state):
    try:
        # Mã hóa đầu vào với attention_mask
        inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(model.device)
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=max_length,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            pad_token_id=tokenizer.pad_token_id
        )
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Cập nhật state với kết quả mới
        state.append(generated_text)
        return state, generated_text  # Trả về state và output để hiển thị
    except Exception as e:
        error_msg = f"Error: {str(e)}"
        state.append(error_msg)
        return state, error_msg

# Hàm hiển thị thông tin API
def get_api_info():
    base_url = "https://<your-space-name>.hf.space"
    return (
        "Welcome to Qwen2.5-0.5B API!\n"
        f"API Base URL: {base_url} (Replace '<your-space-name>' with your actual Space name)\n"
        "Endpoints:\n"
        f"- GET {base_url}/api/health_check (Check API status)\n"
        f"- POST {base_url}/api/generate (Generate text)\n"
        "To use the generate API, send a POST request with JSON:\n"
        '{"0": "your prompt", "1": 150}'
    )

# Hàm kiểm tra sức khỏe (dành cho API)
def health_check():
    return "Qwen2.5-0.5B API is running!"

# Tạo giao diện Gradio
with gr.Blocks(title="Qwen2.5-0.5B Text Generator") as demo:
    gr.Markdown("# Qwen2.5-0.5B Text Generator")
    gr.Markdown("Enter a prompt below or use the API!")
    
    # State để lưu trữ lịch sử kết quả
    state = gr.State(value=[])  # Khởi tạo state là danh sách rỗng
    
    # Hiển thị thông tin API
    gr.Markdown("### API Information")
    api_info = gr.Textbox(label="API Details", value=get_api_info(), interactive=False)
    
    # Giao diện sinh văn bản
    gr.Markdown("### Generate Text")
    with gr.Row():
        prompt_input = gr.Textbox(label="Prompt", placeholder="Type something...")
        max_length_input = gr.Slider(50, 500, value=100, step=10, label="Max Length")
    
    generate_button = gr.Button("Generate")
    output_text = gr.Textbox(label="Generated Text History", interactive=False, lines=10)
    
    # Liên kết button với hàm generate_text
    generate_button.click(
        fn=generate_text,
        inputs=[prompt_input, max_length_input, state],
        outputs=[state, output_text]  # Cập nhật cả state và output_text
    )

# Định nghĩa API endpoints với Gradio
interface = gr.Interface(
    fn=lambda prompt, max_length: generate_text(prompt, max_length, [])[1],  # Chỉ lấy output, không dùng state cho API
    inputs=["text", "number"],
    outputs="text",
    title="Qwen2.5-0.5B API",
    api_name="/generate"
).queue()

health_interface = gr.Interface(
    fn=health_check,
    inputs=None,
    outputs="text",
    api_name="/health_check"
)

# Gắn các interface vào demo
demo = gr.TabbedInterface([interface, health_interface], ["Generate Text", "Health Check"])

# Chạy ứng dụng
demo.launch(server_name="0.0.0.0", server_port=7860)