rapacious commited on
Commit
45d388a
·
verified ·
1 Parent(s): ccd7e94

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -0
app.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ import torch
5
+
6
+ # Khởi tạo FastAPI
7
+ app = FastAPI()
8
+
9
+ # Tải model và tokenizer khi ứng dụng khởi động
10
+ model_name = "Qwen/Qwen2.5-0.5B"
11
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
12
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
13
+
14
+ # Định nghĩa request body
15
+ class TextInput(BaseModel):
16
+ prompt: str
17
+ max_length: int = 100
18
+
19
+ # API endpoint để sinh văn bản
20
+ @app.post("/generate")
21
+ async def generate_text(input: TextInput):
22
+ try:
23
+ # Mã hóa đầu vào
24
+ inputs = tokenizer(input.prompt, return_tensors="pt").to(model.device)
25
+
26
+ # Sinh văn bản
27
+ outputs = model.generate(
28
+ inputs["input_ids"],
29
+ max_length=input.max_length,
30
+ num_return_sequences=1,
31
+ no_repeat_ngram_size=2,
32
+ do_sample=True,
33
+ top_k=50,
34
+ top_p=0.95
35
+ )
36
+
37
+ # Giải mã kết quả
38
+ generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
39
+ return {"generated_text": generated_text}
40
+ except Exception as e:
41
+ raise HTTPException(status_code=500, detail=str(e))
42
+
43
+ # Endpoint kiểm tra sức khỏe
44
+ @app.get("/")
45
+ async def root():
46
+ return {"message": "Qwen2.5-0.5B API is running!"}