waynebruce2110 committed on
Commit 1fa4878 · verified · 1 Parent(s): 8a21366

Update app.py

Files changed (1)
  1. app.py +11 -14
app.py CHANGED
@@ -1,6 +1,6 @@
 from fastapi import FastAPI
 from pydantic import BaseModel
-from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 import uvicorn
 
@@ -9,19 +9,15 @@ app = FastAPI()
 # Model name (update with your actual model path on Hugging Face)
 model_name = "waynebruce2110/GraveSocialAI"
 
-# Enable 8-bit quantization for CPU
-quantization_config = BitsAndBytesConfig(
-    load_in_8bit=True,  # Enables 8-bit loading
-    llm_int8_enable_fp32_cpu_offload=True  # Ensures it works on CPU
-)
-
-# Load the tokenizer and model with 8-bit quantization
+# Load tokenizer
 tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=False)
+
+# Load model in float16 on CPU
 model = AutoModelForCausalLM.from_pretrained(
     model_name,
     local_files_only=False,
-    device_map="cpu",  # Ensures it loads on CPU
-    quantization_config=quantization_config
+    torch_dtype=torch.float16,  # Reduces memory usage
+    device_map="cpu"  # Forces model to load on CPU
 )
 
 # Define input schema
@@ -34,8 +30,9 @@ def read_root():
 
 @app.post("/generate/")
 def generate_text(data: PromptInput):
-    inputs = tokenizer(data.prompt, return_tensors="pt")
-    outputs = model.generate(**inputs, max_length=100)
+    inputs = tokenizer(data.prompt, return_tensors="pt").to("cpu")  # Ensure input is on CPU
+    with torch.no_grad():
+        outputs = model.generate(**inputs, max_length=100)
     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
     return {"generated_text": response}
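app.py imports uvicorn, but the server entrypoint sits outside the hunks shown in this commit. A minimal launch sketch, assuming the FastAPI instance is the `app` object from this file; the host and port are assumptions, not confirmed by the diff:

# Hypothetical entrypoint sketch; the actual launch code is not part of this diff.
# Port 8000 is an assumption (Hugging Face Spaces commonly use 7860 instead).
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)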
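To exercise the updated /generate/ endpoint, a small client call is enough. A sketch assuming the server is reachable at http://127.0.0.1:8000 (the URL and port are assumptions); the request body follows the PromptInput schema used by generate_text, and the response key matches its return statement:

# Client sketch for POST /generate/.
import requests

resp = requests.post(
    "http://127.0.0.1:8000/generate/",
    json={"prompt": "Write a short greeting."},  # PromptInput expects a `prompt` field
    timeout=300,  # float16 generation on CPU can be slow
)
resp.raise_for_status()
print(resp.json()["generated_text"])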