CamiloVega committed on
Commit bd4d51a · verified · 1 Parent(s): e5add36

Upload app (4).py

Files changed (1)
  1. app (4).py +342 -0
app (4).py ADDED
@@ -0,0 +1,342 @@
+ import spaces
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
+ import gradio as gr
+ import torch
+ import logging
+ import os
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(levelname)s - %(message)s'
+ )
+ logger = logging.getLogger(__name__)
+
+ # Get HuggingFace token from environment variable
+ hf_token = os.environ.get('HUGGINGFACE_TOKEN')
+ if not hf_token:
+     logger.error("HUGGINGFACE_TOKEN environment variable not set")
+     raise ValueError("Please set the HUGGINGFACE_TOKEN environment variable")
+
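+ # (Illustrative) Locally, the token can be exported before launching the app,
+ # e.g. `export HUGGINGFACE_TOKEN=hf_xxx`; in a Hugging Face Space it is
+ # usually configured as a repository secret instead.
+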
+ # Define the model name
+ model_name = "meta-llama/Llama-2-7b-hf"
+
+ try:
+     logger.info("Starting model initialization...")
+
+     # Check CUDA availability
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     logger.info(f"Using device: {device}")
+
+     # Configure PyTorch settings
+     if device == "cuda":
+         torch.backends.cuda.matmul.allow_tf32 = True
+         torch.backends.cudnn.allow_tf32 = True
+
+     # Load tokenizer
+     logger.info("Loading tokenizer...")
+     tokenizer = AutoTokenizer.from_pretrained(
+         model_name,
+         trust_remote_code=True,
+         token=hf_token
+     )
+     tokenizer.pad_token = tokenizer.eos_token
+     logger.info("Tokenizer loaded successfully")
+
+     # Load model with basic configuration; device_map="auto" lets
+     # accelerate place the weights across the available devices
+     logger.info("Loading model...")
+     model = AutoModelForCausalLM.from_pretrained(
+         model_name,
+         torch_dtype=torch.float16 if device == "cuda" else torch.float32,
+         trust_remote_code=True,
+         token=hf_token,
+         device_map="auto"
+     )
+     logger.info("Model loaded successfully")
+
+     # Create pipeline (the model was already placed via device_map above)
+     logger.info("Creating generation pipeline...")
+     model_gen = pipeline(
+         "text-generation",
+         model=model,
+         tokenizer=tokenizer,
+         max_new_tokens=256,
+         do_sample=True,
+         temperature=0.7,
+         top_p=0.9,
+         repetition_penalty=1.1
+     )
+     logger.info("Pipeline created successfully")
+
+ except Exception as e:
+     logger.error(f"Error during initialization: {str(e)}")
+     raise
+
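+ # For reference: the pipeline returns a list with one dict per generated
+ # sequence, so model_gen(prompt)[0]["generated_text"] yields the continuation
+ # as a plain string (return_full_text=False strips the prompt from it).
+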
+ # Configure system message
+ system_message = """You are a helpful AI assistant called AQuaBot. You provide direct, clear, and detailed answers to questions while being aware of environmental impact. Keep your responses natural and informative, but concise. Always provide context and explanations with your answers. Respond directly to questions without using any special tags or markers."""
+
+ @spaces.GPU(duration=60)
+ @torch.inference_mode()
+ def generate_response(user_input, chat_history):
+     try:
+         logger.info("Generating response for user input...")
+         global total_water_consumption
+
+         # Calculate water consumption for input
+         input_water_consumption = calculate_water_consumption(user_input, True)
+         total_water_consumption += input_water_consumption
+
+         # Create prompt with Llama 2 chat format
+         conversation_history = ""
+         if chat_history:
+             for message in chat_history:
+                 # Remove any [INST] tags from the history
+                 user_msg = message[0].replace("[INST]", "").replace("[/INST]", "").strip()
+                 assistant_msg = message[1].replace("[INST]", "").replace("[/INST]", "").strip()
+                 conversation_history += f"[INST] {user_msg} [/INST] {assistant_msg} "
+
+         # Simplified variant of the Llama 2 chat template (the official
+         # template also wraps the system prompt in <<SYS>> markers)
+         prompt = f"<s>[INST] {system_message}\n\n{conversation_history}[INST] {user_input} [/INST]"
+
+         logger.info("Generating model response...")
+         outputs = model_gen(
+             prompt,
+             max_new_tokens=256,
+             return_full_text=False,
+             pad_token_id=tokenizer.eos_token_id,
+             do_sample=True,
+             temperature=0.7,
+             top_p=0.9,
+             repetition_penalty=1.1
+         )
+         logger.info("Model response generated successfully")
+
+         # Clean up the response by removing any [INST] tags and trimming
+         assistant_response = outputs[0]['generated_text'].strip()
+         assistant_response = assistant_response.replace("[INST]", "").replace("[/INST]", "").strip()
+
+         # If the response is too short, try to generate a more detailed one
+         if len(assistant_response.split()) < 10:
+             prompt += "\nPlease provide a more detailed answer with context and explanation."
+             outputs = model_gen(
+                 prompt,
+                 max_new_tokens=256,
+                 return_full_text=False,
+                 pad_token_id=tokenizer.eos_token_id,
+                 do_sample=True,
+                 temperature=0.7,
+                 top_p=0.9,
+                 repetition_penalty=1.1
+             )
+             assistant_response = outputs[0]['generated_text'].strip()
+             assistant_response = assistant_response.replace("[INST]", "").replace("[/INST]", "").strip()
+
+         # Calculate water consumption for output
+         output_water_consumption = calculate_water_consumption(assistant_response, False)
+         total_water_consumption += output_water_consumption
+
+         # Update chat history with the cleaned messages
+         chat_history.append([user_input, assistant_response])
+
+         # Refresh the water consumption indicator
+         water_message = make_water_html(total_water_consumption)
+
+         return chat_history, water_message
+
+     except Exception as e:
+         logger.error(f"Error in generate_response: {str(e)}")
+         error_message = f"An error occurred: {str(e)}"
+         chat_history.append([user_input, error_message])
+         return chat_history, make_water_html(total_water_consumption)
+
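+ # For illustration, a first-turn prompt assembled above has the shape:
+ #   <s>[INST] <system message>
+ #
+ #   [INST] <user input> [/INST]
+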
+ # Constants for water consumption calculation (ml per token)
+ WATER_PER_TOKEN = {
+     "input_training": 0.0000309,
+     "output_training": 0.0000309,
+     "input_inference": 0.05,
+     "output_inference": 0.05
+ }
+
+ # Initialize variables
+ total_water_consumption = 0
+
+ def calculate_tokens(text):
+     try:
+         return len(tokenizer.encode(text))
+     except Exception as e:
+         logger.error(f"Error calculating tokens: {str(e)}")
+         return len(text.split()) + len(text) // 4  # Fallback to approximation
+
+ def calculate_water_consumption(text, is_input=True):
+     tokens = calculate_tokens(text)
+     if is_input:
+         return tokens * (WATER_PER_TOKEN["input_training"] + WATER_PER_TOKEN["input_inference"])
+     return tokens * (WATER_PER_TOKEN["output_training"] + WATER_PER_TOKEN["output_inference"])
+
+ def format_message(role, content):
+     return {"role": role, "content": content}
+
+ def make_water_html(total_ml):
+     # Floating indicator showing the cumulative water consumption
+     return f"""
+     <div style="position: fixed; top: 20px; right: 20px;
+                 background-color: white; padding: 15px;
+                 border: 2px solid #ff0000; border-radius: 10px;
+                 box-shadow: 0 2px 4px rgba(0,0,0,0.1);">
+         <div style="color: #ff0000; font-size: 24px; font-weight: bold;">
+             💧 {total_ml:.4f} ml
+         </div>
+         <div style="color: #666; font-size: 14px;">
+             Water Consumed
+         </div>
+     </div>
+     """
+
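+ # Worked example with the constants above: a 20-token input message adds
+ #   20 * (0.0000309 + 0.05) ≈ 1.0006 ml
+ # to the running total, i.e. the inference term dominates the estimate.
+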
+ # Create Gradio interface
+ try:
+     logger.info("Creating Gradio interface...")
+     with gr.Blocks(css="div.gradio-container {background-color: #f0f2f6}") as demo:
+         gr.HTML("""
+         <div style="text-align: center; max-width: 800px; margin: 0 auto; padding: 20px;">
+             <h1 style="color: #2d333a;">AQuaBot</h1>
+             <p style="color: #4a5568;">
+                 Welcome to AQuaBot - an AI assistant that helps raise awareness
+                 about water consumption in language models.
+             </p>
+         </div>
+         """)
+
+         chatbot = gr.Chatbot()
+         message = gr.Textbox(
+             placeholder="Type your message here...",
+             show_label=False
+         )
+         show_water = gr.HTML(make_water_html(0))
+         clear = gr.Button("Clear Chat")
+
+         # Add footer with citation and disclaimer
+         gr.HTML("""
+         <div style="text-align: center; max-width: 800px; margin: 20px auto; padding: 20px;
+                     background-color: #f8f9fa; border-radius: 10px;">
+             <div style="margin-bottom: 15px;">
+                 <p style="color: #666; font-size: 14px; font-style: italic;">
+                     Water consumption calculations are based on the study:<br>
+                     Li, P. et al. (2023). Making AI Less Thirsty: Uncovering and Addressing the Secret Water
+                     Footprint of AI Models. arXiv preprint,
+                     <a href="https://arxiv.org/abs/2304.03271" target="_blank">https://arxiv.org/abs/2304.03271</a>
+                 </p>
+             </div>
+             <div style="border-top: 1px solid #ddd; padding-top: 15px;">
+                 <p style="color: #666; font-size: 14px;">
+                     <strong>Important note:</strong> This application uses the Meta Llama-2-7b model
+                     instead of GPT-3 for availability and cost reasons. However,
+                     the water consumption figures per token (input/output) are based on the
+                     conclusions of the cited paper.
+                 </p>
+             </div>
+         </div>
+         """)
+
+         def submit(user_input, chat_history):
+             return generate_response(user_input, chat_history)
+
+         # Configure event handlers
+         message.submit(submit, [message, chatbot], [chatbot, show_water])
+         clear.click(
+             lambda: ([], make_water_html(0)),
+             None,
+             [chatbot, show_water]
+         )
+
+     logger.info("Gradio interface created successfully")
+
+     # Launch the application
+     logger.info("Launching application...")
+     demo.launch()
+
+ except Exception as e:
+     logger.error(f"Error in Gradio interface creation: {str(e)}")
+     raise
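+
+ # Note: demo.launch() serves the app on Gradio's default local port (7860)
+ # when run directly; on a Hugging Face Space, hosting is handled by the platform.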