CamiloVega committed on
Commit 0a8b519 · verified · 1 Parent(s): 0e8cca3

Create app.py

Files changed (1)
app.py +259 -0
app.py ADDED
@@ -0,0 +1,259 @@
+ import spaces
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
+ import gradio as gr
+ import torch
+ import logging
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(levelname)s - %(message)s'
+ )
+ logger = logging.getLogger(__name__)
+
+ # Define the model name
+ model_name = "microsoft/phi-2"
+
+ try:
+     logger.info("Starting model initialization...")
+
+     # Check CUDA availability
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     logger.info(f"Using device: {device}")
+
+     # Configure PyTorch settings
+     if device == "cuda":
+         torch.backends.cuda.matmul.allow_tf32 = True
+         torch.backends.cudnn.allow_tf32 = True
+
+     # Load tokenizer
+     logger.info("Loading tokenizer...")
+     tokenizer = AutoTokenizer.from_pretrained(
+         model_name,
+         trust_remote_code=True
+     )
+     logger.info("Tokenizer loaded successfully")
+
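+     # Note: device_map="auto" below lets accelerate place the weights on the
+     # available hardware (GPU if present, otherwise CPU); fp16 is used only on GPU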
+     # Load model
+     logger.info("Loading model...")
+     model = AutoModelForCausalLM.from_pretrained(
+         model_name,
+         torch_dtype=torch.float16 if device == "cuda" else torch.float32,
+         device_map="auto",
+         trust_remote_code=True
+     )
+     logger.info("Model loaded successfully")
+
+     # Create pipeline, reusing the already-placed model and tokenizer
+     logger.info("Creating generation pipeline...")
+     model_gen = pipeline(
+         "text-generation",
+         model=model,
+         tokenizer=tokenizer,
+         max_new_tokens=256,
+         do_sample=True,
+         temperature=0.7,
+         top_p=0.9,
+         repetition_penalty=1.1
+     )
+     logger.info("Pipeline created successfully")
+
+ except Exception as e:
+     logger.error(f"Error during initialization: {str(e)}")
+     raise
+
+ # Configure system message
+ system_message = {
+     "role": "system",
+     "content": """You are AQuaBot, an AI assistant aware of environmental impact.
+ You help users with any topic while raising awareness about water consumption
+ in AI. Did you know that training GPT-3 consumed 5.4 million liters of water,
+ equivalent to the daily consumption of a city of 10,000 people?"""
+ }
+
+ # Constants for water consumption calculation (ml per token)
+ WATER_PER_TOKEN = {
+     "input_training": 0.0000309,
+     "output_training": 0.0000309,
+     "input_inference": 0.05,
+     "output_inference": 0.05
+ }
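+ # These per-token figures follow the Li et al. (2023) estimates cited in the
+ # footer below and should be read as rough, order-of-magnitude numbers.
+ # Worked example: a 100-token prompt adds
+ #     100 * (0.0000309 + 0.05) ≈ 5.003 ml
+ # to the running counter.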
+
+ # Initialize variables
+ messages = [system_message]
+ total_water_consumption = 0
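+ # Note: these module-level globals are shared across all user sessions in a
+ # single Gradio process, so the history and counter aggregate everyone's usage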
+
+ def calculate_tokens(text):
+     try:
+         return len(tokenizer.encode(text))
+     except Exception as e:
+         logger.error(f"Error calculating tokens: {str(e)}")
+         # Fallback approximation: word count plus roughly one token per 4 characters
+         return len(text.split()) + len(text) // 4
+
+ def calculate_water_consumption(text, is_input=True):
+     tokens = calculate_tokens(text)
+     if is_input:
+         return tokens * (WATER_PER_TOKEN["input_training"] + WATER_PER_TOKEN["input_inference"])
+     return tokens * (WATER_PER_TOKEN["output_training"] + WATER_PER_TOKEN["output_inference"])
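+
+ def format_water_message(total):
+     """Render the floating water-consumption counter as HTML.
+
+     Factored into one place so the initial display, each response, and the
+     Clear button all render the same widget.
+     """
+     return f"""
+     <div style="position: fixed; top: 20px; right: 20px;
+                 background-color: white; padding: 15px;
+                 border: 2px solid #ff0000; border-radius: 10px;
+                 box-shadow: 0 2px 4px rgba(0,0,0,0.1);">
+         <div style="color: #ff0000; font-size: 24px; font-weight: bold;">
+             💧 {total:.4f} ml
+         </div>
+         <div style="color: #666; font-size: 14px;">
+             Water Consumed
+         </div>
+     </div>
+     """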
+
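+ # spaces.GPU requests a ZeroGPU slot for up to `duration` seconds per call;
+ # torch.inference_mode() disables autograd tracking during generation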
+ @spaces.GPU(duration=60)
+ @torch.inference_mode()
+ def generate_response(user_input, chat_history):
+     try:
+         logger.info("Generating response for user input...")
+         global total_water_consumption, messages
+
+         # Calculate water consumption for input
+         input_water_consumption = calculate_water_consumption(user_input, True)
+         total_water_consumption += input_water_consumption
+
+         # Add user input to messages
+         messages.append({"role": "user", "content": user_input})
+
+         # Create prompt
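+         # (Phi-2 is a base model with no built-in chat template, so the
+         # conversation is serialized as a plain User:/Assistant: transcript)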
+         prompt = ""
+         for m in messages:
+             if m["role"] == "system":
+                 prompt += f"<START SYSTEM MESSAGE>\n{m['content']}\n<END SYSTEM MESSAGE>\n\n"
+             elif m["role"] == "user":
+                 prompt += f"User: {m['content']}\n"
+             else:
+                 prompt += f"Assistant: {m['content']}\n"
+         prompt += "Assistant:"
+
+         logger.info("Generating model response...")
+         outputs = model_gen(
+             prompt,
+             max_new_tokens=256,
+             return_full_text=False,
+             pad_token_id=tokenizer.eos_token_id,
+         )
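+         # return_full_text=False makes the pipeline return only the newly
+         # generated completion, not the prompt we fed in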
+         logger.info("Model response generated successfully")
+
+         assistant_response = outputs[0]['generated_text'].strip()
+
+         # Calculate water consumption for output
+         output_water_consumption = calculate_water_consumption(assistant_response, False)
+         total_water_consumption += output_water_consumption
+
+         # Add assistant's response to messages
+         messages.append({"role": "assistant", "content": assistant_response})
+
+         # Update chat history; gr.Chatbot(type="messages") expects role/content
+         # dicts rather than (user, bot) tuples
+         chat_history.append({"role": "user", "content": user_input})
+         chat_history.append({"role": "assistant", "content": assistant_response})
+
+         # Prepare water consumption message
+         water_message = format_water_message(total_water_consumption)
+
+         return chat_history, water_message
+
+     except Exception as e:
+         logger.error(f"Error in generate_response: {str(e)}")
+         error_message = f"An error occurred: {str(e)}"
+         chat_history.append({"role": "user", "content": user_input})
+         chat_history.append({"role": "assistant", "content": error_message})
+         # Return a rendered counter so the HTML output stays valid on errors
+         return chat_history, format_water_message(total_water_consumption)
+
+ # Create Gradio interface
+ try:
+     logger.info("Creating Gradio interface...")
+     with gr.Blocks(css="div.gradio-container {background-color: #f0f2f6}") as demo:
+         gr.HTML("""
+         <div style="text-align: center; max-width: 800px; margin: 0 auto; padding: 20px;">
+             <h1 style="color: #2d333a;">AQuaBot</h1>
+             <p style="color: #4a5568;">
+                 Welcome to AQuaBot, an AI assistant that helps raise awareness about water
+                 consumption in language models.
+             </p>
+         </div>
+         """)
+
+         chatbot = gr.Chatbot(type="messages")
+         message = gr.Textbox(
+             placeholder="Type your message here...",
+             show_label=False
+         )
+         show_water = gr.HTML(format_water_message(0))
+         clear = gr.Button("Clear Chat")
+
+         # Add footer with citation and disclaimer
+         gr.HTML("""
+         <div style="text-align: center; max-width: 800px; margin: 20px auto; padding: 20px;
+                     background-color: #f8f9fa; border-radius: 10px;">
+             <div style="margin-bottom: 15px;">
+                 <p style="color: #666; font-size: 14px; font-style: italic;">
+                     Water consumption calculations are based on the study:<br>
+                     Li, P. et al. (2023). Making AI Less Thirsty: Uncovering and Addressing the
+                     Secret Water Footprint of AI Models. arXiv preprint,
+                     <a href="https://arxiv.org/abs/2304.03271" target="_blank">https://arxiv.org/abs/2304.03271</a>
+                 </p>
+             </div>
+             <div style="border-top: 1px solid #ddd; padding-top: 15px;">
+                 <p style="color: #666; font-size: 14px;">
+                     <strong>Important note:</strong> This application uses Microsoft's Phi-2 model
+                     instead of GPT-3 for availability and cost reasons. However, the per-token
+                     water consumption figures (input/output) are based on the conclusions of the
+                     cited paper.
+                 </p>
+             </div>
+         </div>
+         """)
+
+         def clear_chat():
+             # Reset the shared conversation state so the visible chat, the
+             # model context, and the counter stay consistent
+             global messages, total_water_consumption
+             messages = [system_message]
+             total_water_consumption = 0
+             return [], format_water_message(0)
+
+         # Configure event handlers
+         message.submit(generate_response, [message, chatbot], [chatbot, show_water])
+         clear.click(clear_chat, None, [chatbot, show_water])
+
+     logger.info("Gradio interface created successfully")
+
+     # Launch the application
+     logger.info("Launching application...")
+     demo.launch()
+
+ except Exception as e:
+     logger.error(f"Error in Gradio interface creation: {str(e)}")
+     raise