OnlyCheeini commited on
Commit
78f2093
·
verified ·
1 Parent(s): 301b6fb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -2
app.py CHANGED
@@ -12,7 +12,10 @@ app = FastAPI()
12
  # Load your fine-tuned model and tokenizer
13
  model_name = "OnlyCheeini/greesychat-turbo"
14
  tokenizer = AutoTokenizer.from_pretrained(model_name)
15
- model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to("cuda")
 
 
 
16
 
17
  class OpenAIRequest(BaseModel):
18
  model: str
@@ -29,7 +32,7 @@ async def generate_text(request: OpenAIRequest):
29
  if request.model != model_name:
30
  raise HTTPException(status_code=400, detail="Model not found")
31
 
32
- inputs = tokenizer(request.prompt, return_tensors="pt").to("cuda")
33
  outputs = model.generate(
34
  **inputs,
35
  max_length=inputs['input_ids'].shape[1] + request.max_tokens,
 
12
  # Load your fine-tuned model and tokenizer
13
  model_name = "OnlyCheeini/greesychat-turbo"
14
  tokenizer = AutoTokenizer.from_pretrained(model_name)
15
+
16
+ # Check if a GPU is available, otherwise use CPU
17
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
+ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32).to(device)
19
 
20
  class OpenAIRequest(BaseModel):
21
  model: str
 
32
  if request.model != model_name:
33
  raise HTTPException(status_code=400, detail="Model not found")
34
 
35
+ inputs = tokenizer(request.prompt, return_tensors="pt").to(device)
36
  outputs = model.generate(
37
  **inputs,
38
  max_length=inputs['input_ids'].shape[1] + request.max_tokens,