rapacious commited on
Commit
59cd1f7
·
verified ·
1 Parent(s): df8c7cb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -3
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -8,8 +9,18 @@ app = FastAPI()
8
 
9
  # Tải model và tokenizer khi ứng dụng khởi động
10
  model_name = "Qwen/Qwen2.5-0.5B"
11
- tokenizer = AutoTokenizer.from_pretrained(model_name)
12
- model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
 
 
 
 
 
 
 
 
 
 
13
 
14
  # Định nghĩa request body
15
  class TextInput(BaseModel):
@@ -43,4 +54,8 @@ async def generate_text(input: TextInput):
43
  # Endpoint kiểm tra sức khỏe
44
  @app.get("/")
45
  async def root():
46
- return {"message": "Qwen2.5-0.5B API is running!"}
 
 
 
 
 
1
+ import uvicorn
2
  from fastapi import FastAPI, HTTPException
3
  from pydantic import BaseModel
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
9
 
10
  # Tải model và tokenizer khi ứng dụng khởi động
11
  model_name = "Qwen/Qwen2.5-0.5B"
12
+ try:
13
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
14
+ model = AutoModelForCausalLM.from_pretrained(
15
+ model_name,
16
+ torch_dtype="auto",
17
+ device_map="auto",
18
+ attn_implementation="eager" # Tránh cảnh báo sdpa
19
+ )
20
+ print("Model and tokenizer loaded successfully!")
21
+ except Exception as e:
22
+ print(f"Error loading model: {e}")
23
+ raise
24
 
25
  # Định nghĩa request body
26
  class TextInput(BaseModel):
 
54
  # Endpoint kiểm tra sức khỏe
55
  @app.get("/")
56
  async def root():
57
+ return {"message": "Qwen2.5-0.5B API is running!"}
58
+
59
+ # Chạy server khi file được gọi trực tiếp
60
+ if __name__ == "__main__":
61
+ uvicorn.run(app, host="0.0.0.0", port=7860)