veerukhannan committed on
Commit
693be4f
·
verified ·
1 Parent(s): 121ef90

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -5
app.py CHANGED
@@ -23,13 +23,12 @@ logger = logging.getLogger(__name__)
23
  load_dotenv()
24
  login(token=os.getenv("HUGGINGFACE_API_TOKEN"))
25
 
26
- # Initialize model with optimized settings
27
  model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
28
  model = AutoModelForCausalLM.from_pretrained(
29
  model_name,
30
- torch_dtype=torch.float16,
31
  device_map="auto",
32
- load_in_8bit=True
33
  )
34
  tokenizer = AutoTokenizer.from_pretrained(model_name)
35
 
@@ -42,7 +41,7 @@ class LegalTextSearchBot:
42
  )
43
  self.collection = self.astra_db.collection("legal_content")
44
 
45
- # Initialize pipeline with optimized settings
46
  pipe = pipeline(
47
  "text-generation",
48
  model=model,
@@ -51,7 +50,6 @@ class LegalTextSearchBot:
51
  temperature=0.7,
52
  top_p=0.95,
53
  repetition_penalty=1.15,
54
- torch_dtype=torch.float16,
55
  device_map="auto"
56
  )
57
  self.llm = HuggingFacePipeline(pipeline=pipe)
 
23
  load_dotenv()
24
  login(token=os.getenv("HUGGINGFACE_API_TOKEN"))
25
 
26
+ # Initialize model with CPU-compatible settings
27
  model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
28
  model = AutoModelForCausalLM.from_pretrained(
29
  model_name,
 
30
  device_map="auto",
31
+ torch_dtype=torch.float32, # Use float32 for CPU compatibility
32
  )
33
  tokenizer = AutoTokenizer.from_pretrained(model_name)
34
 
 
41
  )
42
  self.collection = self.astra_db.collection("legal_content")
43
 
44
+ # Initialize pipeline with CPU settings
45
  pipe = pipeline(
46
  "text-generation",
47
  model=model,
 
50
  temperature=0.7,
51
  top_p=0.95,
52
  repetition_penalty=1.15,
 
53
  device_map="auto"
54
  )
55
  self.llm = HuggingFacePipeline(pipeline=pipe)