CamiloVega committed on
Commit
a9cb5eb
·
verified ·
1 Parent(s): 935cc40

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -65
app.py CHANGED
@@ -1,4 +1,3 @@
1
- import os
2
  import spaces
3
  from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
4
  import gradio as gr
@@ -6,11 +5,6 @@ import torch
6
  import logging
7
  import sys
8
  from accelerate import infer_auto_device_map, init_empty_weights
9
- from huggingface_hub import login
10
- from dotenv import load_dotenv
11
-
12
- # Load environment variables
13
- load_dotenv()
14
 
15
  # Configure logging
16
  logging.basicConfig(
@@ -19,22 +13,8 @@ logging.basicConfig(
19
  )
20
  logger = logging.getLogger(__name__)
21
 
22
- # Get HuggingFace token from environment variable
23
- hf_token = os.getenv('HUGGINGFACE_TOKEN')
24
- if not hf_token:
25
- logger.error("HUGGINGFACE_TOKEN environment variable not found")
26
- raise ValueError("Please set the HUGGINGFACE_TOKEN environment variable")
27
-
28
- # Login to Hugging Face
29
- try:
30
- login(token=hf_token)
31
- logger.info("Successfully logged in to Hugging Face")
32
- except Exception as e:
33
- logger.error(f"Failed to login to Hugging Face: {str(e)}")
34
- raise
35
-
36
  # Define the model name
37
- model_name = "meta-llama/Llama-2-7b-hf"
38
 
39
  try:
40
  logger.info("Starting model initialization...")
@@ -56,14 +36,13 @@ try:
56
  )
57
  logger.info("Tokenizer loaded successfully")
58
 
59
- # Load model with 8-bit quantization
60
  logger.info("Loading model...")
61
  model = AutoModelForCausalLM.from_pretrained(
62
  model_name,
63
  torch_dtype=torch.float16 if device == "cuda" else torch.float32,
64
- trust_remote_code=True,
65
- load_in_8bit=True,
66
- device_map="auto"
67
  )
68
  logger.info("Model loaded successfully")
69
 
@@ -87,14 +66,13 @@ except Exception as e:
87
  raise
88
 
89
  # Configure system message
90
- system_message = """You are AQuaBot, an AI assistant aware of environmental impact.
91
- You help users with any topic while raising awareness about water consumption
92
- in AI. Did you know that training GPT-3 consumed 5.4 million liters of water,
93
- equivalent to the daily consumption of a city of 10,000 people?"""
94
-
95
- # Llama 2 specific tokens
96
- B_INST, E_INST = "[INST]", "[/INST]"
97
- B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
98
 
99
  # Constants for water consumption calculation
100
  WATER_PER_TOKEN = {
@@ -105,6 +83,7 @@ WATER_PER_TOKEN = {
105
  }
106
 
107
  # Initialize variables
 
108
  total_water_consumption = 0
109
 
110
  def calculate_tokens(text):
@@ -120,33 +99,30 @@ def calculate_water_consumption(text, is_input=True):
120
  return tokens * (WATER_PER_TOKEN["input_training"] + WATER_PER_TOKEN["input_inference"])
121
  return tokens * (WATER_PER_TOKEN["output_training"] + WATER_PER_TOKEN["output_inference"])
122
 
123
- def format_prompt(user_input, chat_history):
124
- """
125
- Format the prompt according to Llama 2 specific style
126
- """
127
- prompt = f"{B_INST}{B_SYS}{system_message}{E_SYS}"
128
-
129
- if chat_history:
130
- for user_msg, assistant_msg in chat_history:
131
- prompt += f"{user_msg}{E_INST}{assistant_msg}{B_INST}"
132
-
133
- prompt += f"{user_input}{E_INST}"
134
-
135
- return prompt
136
-
137
  @spaces.GPU(duration=60)
138
  @torch.inference_mode()
139
  def generate_response(user_input, chat_history):
140
  try:
141
  logger.info("Generating response for user input...")
142
- global total_water_consumption
143
 
144
  # Calculate water consumption for input
145
  input_water_consumption = calculate_water_consumption(user_input, True)
146
  total_water_consumption += input_water_consumption
147
 
148
- # Format prompt for Llama 2
149
- prompt = format_prompt(user_input, chat_history)
 
 
 
 
 
 
 
 
 
 
 
150
 
151
  logger.info("Generating model response...")
152
  outputs = model_gen(
@@ -163,8 +139,11 @@ def generate_response(user_input, chat_history):
163
  output_water_consumption = calculate_water_consumption(assistant_response, False)
164
  total_water_consumption += output_water_consumption
165
 
 
 
 
166
  # Update chat history
167
- chat_history.append([user_input, assistant_response])
168
 
169
  # Prepare water consumption message
170
  water_message = f"""
@@ -186,7 +165,7 @@ def generate_response(user_input, chat_history):
186
  except Exception as e:
187
  logger.error(f"Error in generate_response: {str(e)}")
188
  error_message = f"An error occurred: {str(e)}"
189
- chat_history.append([user_input, error_message])
190
  return chat_history, show_water
191
 
192
  # Create Gradio interface
@@ -197,13 +176,13 @@ try:
197
  <div style="text-align: center; max-width: 800px; margin: 0 auto; padding: 20px;">
198
  <h1 style="color: #2d333a;">AQuaBot</h1>
199
  <p style="color: #4a5568;">
200
- Welcome to AQuaBot - An AI assistant powered by Llama 2 that helps raise awareness
201
- about water consumption in language models.
202
  </p>
203
  </div>
204
  """)
205
 
206
- chatbot = gr.Chatbot()
207
  message = gr.Textbox(
208
  placeholder="Type your message here...",
209
  show_label=False
@@ -223,7 +202,7 @@ try:
223
  """)
224
  clear = gr.Button("Clear Chat")
225
 
226
- # Add footer with citation, disclaimer, and credits
227
  gr.HTML("""
228
  <div style="text-align: center; max-width: 800px; margin: 20px auto; padding: 20px;
229
  background-color: #f8f9fa; border-radius: 10px;">
@@ -237,15 +216,10 @@ try:
237
  </div>
238
  <div style="border-top: 1px solid #ddd; padding-top: 15px;">
239
  <p style="color: #666; font-size: 14px;">
240
- <strong>Model Information:</strong> This application uses Meta's Llama 2 (7B) model,
241
- a state-of-the-art language model fine-tuned for chat interactions. Water consumption
242
- calculations are based on the methodology from the cited paper.
243
- </p>
244
- </div>
245
- <div style="border-top: 1px solid #ddd; margin-top: 15px; padding-top: 15px;">
246
- <p style="color: #666; font-size: 14px;">
247
- Created by Camilo Vega - AI Consultant<br>
248
- <a href="https://github.com/vegadevs/aquabot" target="_blank">GitHub Repository</a>
249
  </p>
250
  </div>
251
  </div>
 
 
1
  import spaces
2
  from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
3
  import gradio as gr
 
5
  import logging
6
  import sys
7
  from accelerate import infer_auto_device_map, init_empty_weights
 
 
 
 
 
8
 
9
  # Configure logging
10
  logging.basicConfig(
 
13
  )
14
  logger = logging.getLogger(__name__)
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  # Define the model name
17
+ model_name = "microsoft/phi-2"
18
 
19
  try:
20
  logger.info("Starting model initialization...")
 
36
  )
37
  logger.info("Tokenizer loaded successfully")
38
 
39
+ # Load model
40
  logger.info("Loading model...")
41
  model = AutoModelForCausalLM.from_pretrained(
42
  model_name,
43
  torch_dtype=torch.float16 if device == "cuda" else torch.float32,
44
+ device_map="auto",
45
+ trust_remote_code=True
 
46
  )
47
  logger.info("Model loaded successfully")
48
 
 
66
  raise
67
 
68
  # Configure system message
69
+ system_message = {
70
+ "role": "system",
71
+ "content": """You are AQuaBot, an AI assistant aware of environmental impact.
72
+ You help users with any topic while raising awareness about water consumption
73
+ in AI. Did you know that training GPT-3 consumed 5.4 million liters of water,
74
+ equivalent to the daily consumption of a city of 10,000 people?"""
75
+ }
 
76
 
77
  # Constants for water consumption calculation
78
  WATER_PER_TOKEN = {
 
83
  }
84
 
85
  # Initialize variables
86
+ messages = [system_message]
87
  total_water_consumption = 0
88
 
89
  def calculate_tokens(text):
 
99
  return tokens * (WATER_PER_TOKEN["input_training"] + WATER_PER_TOKEN["input_inference"])
100
  return tokens * (WATER_PER_TOKEN["output_training"] + WATER_PER_TOKEN["output_inference"])
101
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  @spaces.GPU(duration=60)
103
  @torch.inference_mode()
104
  def generate_response(user_input, chat_history):
105
  try:
106
  logger.info("Generating response for user input...")
107
+ global total_water_consumption, messages
108
 
109
  # Calculate water consumption for input
110
  input_water_consumption = calculate_water_consumption(user_input, True)
111
  total_water_consumption += input_water_consumption
112
 
113
+ # Add user input to messages
114
+ messages.append({"role": "user", "content": user_input})
115
+
116
+ # Create prompt
117
+ prompt = ""
118
+ for m in messages:
119
+ if m["role"] == "system":
120
+ prompt += f"<START SYSTEM MESSAGE>\n{m['content']}\n<END SYSTEM MESSAGE>\n\n"
121
+ elif m["role"] == "user":
122
+ prompt += f"User: {m['content']}\n"
123
+ else:
124
+ prompt += f"Assistant: {m['content']}\n"
125
+ prompt += "Assistant:"
126
 
127
  logger.info("Generating model response...")
128
  outputs = model_gen(
 
139
  output_water_consumption = calculate_water_consumption(assistant_response, False)
140
  total_water_consumption += output_water_consumption
141
 
142
+ # Add assistant's response to messages
143
+ messages.append({"role": "assistant", "content": assistant_response})
144
+
145
  # Update chat history
146
+ chat_history.append((user_input, assistant_response))
147
 
148
  # Prepare water consumption message
149
  water_message = f"""
 
165
  except Exception as e:
166
  logger.error(f"Error in generate_response: {str(e)}")
167
  error_message = f"An error occurred: {str(e)}"
168
+ chat_history.append((user_input, error_message))
169
  return chat_history, show_water
170
 
171
  # Create Gradio interface
 
176
  <div style="text-align: center; max-width: 800px; margin: 0 auto; padding: 20px;">
177
  <h1 style="color: #2d333a;">AQuaBot</h1>
178
  <p style="color: #4a5568;">
179
+ Welcome to AQuaBot - An AI assistant that helps raise awareness about water
180
+ consumption in language models.
181
  </p>
182
  </div>
183
  """)
184
 
185
+ chatbot = gr.Chatbot(type="messages")
186
  message = gr.Textbox(
187
  placeholder="Type your message here...",
188
  show_label=False
 
202
  """)
203
  clear = gr.Button("Clear Chat")
204
 
205
+ # Add footer with citation and disclaimer
206
  gr.HTML("""
207
  <div style="text-align: center; max-width: 800px; margin: 20px auto; padding: 20px;
208
  background-color: #f8f9fa; border-radius: 10px;">
 
216
  </div>
217
  <div style="border-top: 1px solid #ddd; padding-top: 15px;">
218
  <p style="color: #666; font-size: 14px;">
219
+ <strong>Important note:</strong> This application uses Microsoft's Phi-2 model
220
+ instead of GPT-3 for availability and cost reasons. However,
221
+ the water consumption calculations per token (input/output) are based on the
222
+ conclusions from the cited paper.
 
 
 
 
 
223
  </p>
224
  </div>
225
  </div>