import gradio as gr
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import uvicorn
# Initialize FastAPI
app = FastAPI()

# Load the model and tokenizer at application startup
model_name = "Qwen/Qwen2.5-0.5B"
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",
        attn_implementation="eager"  # avoid sdpa warnings
    )
    print("Model and tokenizer loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    raise
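# Note: device_map="auto" requires the accelerate package; it must be installed
# alongside transformers and torch for this loading path to work.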
# Request body schema for the API
class TextInput(BaseModel):
    prompt: str
    max_length: int = 100
# Text generation helper (shared by the API and the Gradio UI)
def generate_text(prompt, max_length=100):
    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        outputs = model.generate(
            **inputs,  # pass input_ids and attention_mask together
            max_length=int(max_length),  # cast: Gradio sliders pass floats
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            do_sample=True,
            top_k=50,
            top_p=0.95
        )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        raise RuntimeError(f"Error: {e}") from e
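# Note: max_length counts prompt tokens plus generated tokens, so long prompts
# leave fewer tokens for the continuation; max_new_tokens would bound only the
# continuation, if that behavior is preferred.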
# Text generation endpoint
@app.post("/generate")
async def generate_text_api(input: TextInput):
    try:
        result = generate_text(input.prompt, input.max_length)
        return {"generated_text": result}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
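# Illustrative request against a local run (assumes the default port 7860
# configured at the bottom of this file):
#   curl -X POST http://localhost:7860/generate \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "Once upon a time", "max_length": 120}'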
# Health check endpoint
@app.get("/")
async def root():
    return {"message": "Qwen2.5-0.5B API is running!"}
# Endpoint that reports the base URL and available routes
@app.get("/api_link")
async def get_api_link(request: Request):
    scheme = request.url.scheme
    host = request.url.hostname
    if request.url.port:
        base_url = f"{scheme}://{host}:{request.url.port}"
    else:
        base_url = f"{scheme}://{host}"
    return {
        "api_url": base_url,
        "endpoints": {
            "health_check": f"{base_url}/",
            "generate_text": f"{base_url}/generate",
            "api_link": f"{base_url}/api_link",
            "interface": f"{base_url}/interface"
        }
    }
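# Illustrative response when served locally (values depend on the request URL):
#   {"api_url": "http://localhost:7860",
#    "endpoints": {"health_check": "http://localhost:7860/", ...}}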
# Build the Gradio interface
def create_gradio_interface():
    with gr.Blocks(title="Qwen2.5-0.5B Text Generator") as demo:
        gr.Markdown("# Qwen2.5-0.5B Text Generator")
        gr.Markdown("Enter a prompt and get generated text!")
        with gr.Row():
            prompt_input = gr.Textbox(label="Prompt", placeholder="Type something...")
            max_length_input = gr.Slider(50, 500, value=100, step=10, label="Max Length")
        generate_button = gr.Button("Generate")
        output_text = gr.Textbox(label="Generated Text", interactive=False)
        generate_button.click(
            fn=generate_text,
            inputs=[prompt_input, max_length_input],
            outputs=output_text
        )
    return demo
# Serve the Gradio UI at /interface. Blocks.render() does not return an HTML
# page, so returning it from an HTMLResponse endpoint would fail;
# gr.mount_gradio_app is Gradio's supported way to mount a Blocks app
# on an existing FastAPI app.
app = gr.mount_gradio_app(app, create_gradio_interface(), path="/interface")
# Run the application if not on Hugging Face Spaces
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
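# Illustrative local usage: run this file with Python, then open
# http://localhost:7860/interface for the Gradio UI; on Hugging Face Spaces the
# platform starts the app itself, so the __main__ guard only matters locally.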