Tijmen2 committed on
Commit
9baf1c4
·
verified ·
1 Parent(s): 5d84ead

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -14
app.py CHANGED
@@ -4,21 +4,11 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
4
  import torch
5
  import random
6
 
7
- # Define model parameters for 8-bit quantized loading
8
- model_name = "AstroMLab/AstroSage-8B"
9
 
10
- # Load the tokenizer
11
- tokenizer = AutoTokenizer.from_pretrained(model_name)
12
-
13
- # Load the model with 8-bit quantization using bitsandbytes
14
- model = AutoModelForCausalLM.from_pretrained(
15
- model_name,
16
- torch_dtype=torch.float16,
17
- load_in_8bit=True, # Enable 8-bit quantization
18
- device_map="auto" # Automatically assign layers to available GPUs
19
- )
20
-
21
- streamer = TextStreamer(tokenizer)
22
 
23
  # Placeholder responses for when context is empty
24
  GREETING_MESSAGES = [
@@ -37,7 +27,24 @@ def user(user_message, history):
37
  @spaces.GPU(duration=20)
38
  def bot(history):
39
  """Generate the chatbot response."""
 
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  if not history:
42
  history = []
43
 
 
4
  import torch
5
  import random
6
 
7
+ MODEL_NAME = "AstroMLab/AstroSage-8B"
 
8
 
9
+ model = None
10
+ tokenizer = None
11
+ streamer = None # these will be initialized the first time the bot function runs
 
 
 
 
 
 
 
 
 
12
 
13
  # Placeholder responses for when context is empty
14
  GREETING_MESSAGES = [
 
27
  @spaces.GPU(duration=20)
28
  def bot(history):
29
  """Generate the chatbot response."""
30
+ global model, tokenizer, streamer
31
 
32
+ if not model:
33
+ # initialize the LLM
34
+
35
+ # Load the tokenizer
36
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
37
+
38
+ # Load the model with 8-bit quantization using bitsandbytes
39
+ model = AutoModelForCausalLM.from_pretrained(
40
+ MODEL_NAME,
41
+ torch_dtype=torch.bfloat16,
42
+ load_in_8bit=True, # Enable 8-bit quantization
43
+ device_map="auto" # Automatically assign layers to available GPUs
44
+ )
45
+
46
+ streamer = TextStreamer(tokenizer)
47
+
48
  if not history:
49
  history = []
50