File size: 3,595 Bytes
9a76509
9f3a452
45d388a
 
 
9a76509
 
45d388a
 
 
 
 
 
59cd1f7
 
 
 
 
 
 
 
 
 
 
 
45d388a
9a76509
45d388a
 
 
 
9a76509
 
45d388a
9a76509
45d388a
 
9a76509
45d388a
 
 
 
 
 
9a76509
 
 
 
 
 
 
 
 
 
45d388a
 
 
9a76509
45d388a
 
59cd1f7
 
9a76509
9f3a452
 
9a76509
 
 
 
 
 
1ec8046
 
 
 
 
9a76509
 
1ec8046
 
9f3a452
9a76509
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59cd1f7
1ec8046
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import gradio as gr
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import uvicorn
from fastapi.responses import HTMLResponse

# Create the FastAPI application.
app = FastAPI()

# Load the model and tokenizer once at application startup.
# NOTE(review): this downloads weights on first run — assumes network access
# and enough RAM/VRAM for a 0.5B model; device placement is delegated to
# `device_map="auto"`.
model_name = "Qwen/Qwen2.5-0.5B"
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",
        attn_implementation="eager"  # avoid the sdpa implementation warning
    )
    print("Model and tokenizer loaded successfully!")
except Exception as e:
    # Startup boundary: log the failure, then re-raise so the process aborts
    # instead of serving requests with no model.
    print(f"Error loading model: {e}")
    raise

# Request body schema for the /generate endpoint.
class TextInput(BaseModel):
    """Payload for POST /generate."""
    prompt: str            # text to continue
    max_length: int = 100  # max total sequence length (prompt + generated tokens)

# Text-generation helper (shared by the API endpoint and the Gradio UI).
def generate_text(prompt, max_length=100):
    """Generate a sampled continuation of *prompt* with the module-level model.

    Args:
        prompt: Input text to continue.
        max_length: Maximum total sequence length (prompt + generated tokens).
            Accepts floats too, since Gradio sliders may deliver them.

    Returns:
        The decoded output text (prompt included, special tokens stripped).

    Raises:
        RuntimeError: If tokenization or generation fails (chained to the
            original exception; still caught by callers using ``Exception``).
    """
    try:
        # Encode and move tensors to the device the weights landed on.
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        outputs = model.generate(
            inputs["input_ids"],
            # Pass the mask explicitly — omitting it triggers a transformers
            # warning and can mis-mask padded inputs.
            attention_mask=inputs["attention_mask"],
            max_length=int(max_length),  # generate() requires an int
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            # Qwen tokenizers define no pad token; fall back to EOS to
            # silence the per-call warning.
            pad_token_id=tokenizer.eos_token_id,
        )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # Chain the cause so the API's 500 handler keeps the real traceback.
        raise RuntimeError(f"Error: {str(e)}") from e

# API endpoint for text generation.
@app.post("/generate")
async def generate_text_api(payload: TextInput):
    """Generate text from a JSON body matching ``TextInput``.

    Returns ``{"generated_text": ...}`` on success; any generation failure
    is surfaced as HTTP 500 with the error message as detail.
    """
    # Renamed the body parameter from ``input`` (shadowed the builtin);
    # invisible to HTTP clients since pydantic parses the request body.
    try:
        result = generate_text(payload.prompt, payload.max_length)
        return {"generated_text": result}
    except Exception as e:
        # Chain the cause so server logs keep the original traceback.
        raise HTTPException(status_code=500, detail=str(e)) from e

# Health-check endpoint.
@app.get("/")
async def root():
    """Liveness probe: confirms the service is up and the model loaded."""
    status = {"message": "Qwen2.5-0.5B API is running!"}
    return status

# Endpoint advertising the service's own URLs.
@app.get("/api_link")
async def get_api_link(request: Request):
    """Return the service base URL plus the full URL of every endpoint."""
    url = request.url
    # Only include the port when the request URL carries one.
    port_suffix = f":{url.port}" if url.port else ""
    base_url = f"{url.scheme}://{url.hostname}{port_suffix}"
    endpoint_paths = {
        "health_check": "/",
        "generate_text": "/generate",
        "api_link": "/api_link",
        "interface": "/interface",
    }
    return {
        "api_url": base_url,
        "endpoints": {name: base_url + path for name, path in endpoint_paths.items()},
    }

# Build the Gradio UI.
def create_gradio_interface():
    """Build and return the Gradio Blocks UI for interactive generation."""
    with gr.Blocks(title="Qwen2.5-0.5B Text Generator") as demo:
        gr.Markdown("# Qwen2.5-0.5B Text Generator")
        gr.Markdown("Enter a prompt and get generated text!")

        with gr.Row():
            prompt_box = gr.Textbox(label="Prompt", placeholder="Type something...")
            length_slider = gr.Slider(50, 500, value=100, step=10, label="Max Length")

        run_button = gr.Button("Generate")
        result_box = gr.Textbox(label="Generated Text", interactive=False)

        # Wire the button to the shared generation helper.
        run_button.click(
            fn=generate_text,
            inputs=[prompt_box, length_slider],
            outputs=result_box,
        )
    return demo

# Serve the Gradio UI at /interface.
# NOTE: the previous handler returned HTMLResponse(gradio_app.render()), but
# Blocks.render() does not produce an HTML document — it re-renders the
# components into an enclosing Blocks context and returns the Blocks object,
# so the endpoint served a broken page. gr.mount_gradio_app() is the
# supported way to embed a Gradio app inside an existing FastAPI app; the
# route path ("/interface") is unchanged for clients.
app = gr.mount_gradio_app(app, create_gradio_interface(), path="/interface")

# Run the application when not on Hugging Face Spaces.
# Port 7860 matches the default port Spaces expects, so the same file works
# both locally and when deployed.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)