Update app.py
app.py CHANGED

@@ -1,6 +1,6 @@
 from fastapi import FastAPI
 from pydantic import BaseModel
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 import uvicorn
 
@@ -9,19 +9,15 @@ app = FastAPI()
 # Model name (update with your actual model path on Hugging Face)
 model_name = "waynebruce2110/GraveSocialAI"
 
-#
-quantization_config = BitsAndBytesConfig(
-    load_in_8bit=True, # Enables 8-bit loading
-    llm_int8_enable_fp32_cpu_offload=True # Ensures it works on CPU
-)
-
-# Load the tokenizer and model with 8-bit quantization
+# Load tokenizer
 tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=False)
+
+# Load model in float16 on CPU
 model = AutoModelForCausalLM.from_pretrained(
-    model_name,
-    local_files_only=False,
-
-
+    model_name,
+    local_files_only=False,
+    torch_dtype=torch.float16, # Reduces memory usage
+    device_map="cpu" # Forces model to load on CPU
 )
 
 # Define input schema
@@ -34,8 +30,9 @@ def read_root():
 
 @app.post("/generate/")
 def generate_text(data: PromptInput):
-    inputs = tokenizer(data.prompt, return_tensors="pt")
-
+    inputs = tokenizer(data.prompt, return_tensors="pt").to("cpu") # Ensure input is on CPU
+    with torch.no_grad():
+        outputs = model.generate(**inputs, max_length=100)
     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
     return {"generated_text": response}
 
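Note on the removed block: it referenced BitsAndBytesConfig without ever importing it, and quantization_config was never passed into from_pretrained, so the old version would raise a NameError at startup and the quantization settings had no effect anyway. If 8-bit loading is revisited later, a minimal sketch of a working setup, assuming a CUDA GPU (which the bitsandbytes 8-bit path generally requires), might look like:

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_name = "waynebruce2110/GraveSocialAI"

# 8-bit quantization config; bitsandbytes needs a CUDA device for this path
quantization_config = BitsAndBytesConfig(load_in_8bit=True)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quantization_config,  # must actually be passed in
    device_map="auto",  # let accelerate place the weights on the GPU
)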
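The file imports uvicorn, but the server entrypoint sits outside the hunks shown here. A typical launch block at the bottom of app.py would be a sketch like the one below; port 7860 is the Hugging Face Spaces default and is assumed, not shown in this diff:

if __name__ == "__main__":
    # 7860 is the port Hugging Face Spaces expects by default; adjust if needed
    uvicorn.run(app, host="0.0.0.0", port=7860)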
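To exercise the updated endpoint once the Space is running, something like the following works; the URL is a placeholder, and PromptInput is assumed to expose a single prompt field, as data.prompt in the handler suggests:

import requests

# Placeholder URL: substitute the Space's real address
resp = requests.post(
    "http://localhost:7860/generate/",
    json={"prompt": "Write a short social post about autumn."},
)
resp.raise_for_status()
print(resp.json()["generated_text"])

Note that max_length=100 in the new generate call caps the prompt plus completion at 100 tokens, so long prompts leave little room for output; max_new_tokens would bound only the completion.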