# Hugging Face Space: AQuaBot (running on ZeroGPU hardware)
import spaces
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
import gradio as gr
import torch
import logging
import sys
import os
# Note: infer_auto_device_map and init_empty_weights are imported but not used
# below; device placement is handled via device_map="auto" instead.
from accelerate import infer_auto_device_map, init_empty_weights

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Get HuggingFace token from environment variable
hf_token = os.environ.get('HUGGINGFACE_TOKEN')
if not hf_token:
    logger.error("HUGGINGFACE_TOKEN environment variable not set")
    raise ValueError("Please set the HUGGINGFACE_TOKEN environment variable")
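# The token is read from the environment. On Hugging Face Spaces it would
# typically be added as a repository secret named HUGGINGFACE_TOKEN; for a
# local run you could export it first (illustrative shell command only):
#   export HUGGINGFACE_TOKEN=hf_xxx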
# Define the model name
model_name = "meta-llama/Llama-2-7b-hf"

try:
    logger.info("Starting model initialization...")

    # Check CUDA availability
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {device}")

    # Configure PyTorch settings
    if device == "cuda":
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    # Load tokenizer
    logger.info("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True,
        token=hf_token
    )
    tokenizer.pad_token = tokenizer.eos_token
    logger.info("Tokenizer loaded successfully")

    # Load model with optimized configuration
    logger.info("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        trust_remote_code=True,
        token=hf_token,
        device_map="auto",
        max_memory={0: "12GiB"} if device == "cuda" else None,
        load_in_8bit=True if device == "cuda" else False
    )
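    # Aside: recent transformers releases deprecate passing load_in_8bit directly
    # to from_pretrained in favour of a quantization config. A minimal sketch of
    # the equivalent call (assuming bitsandbytes is installed) would be:
    #   from transformers import BitsAndBytesConfig
    #   model = AutoModelForCausalLM.from_pretrained(
    #       model_name,
    #       quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    #       device_map="auto",
    #       token=hf_token,
    #   )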
logger.info("Model loaded successfully") | |
# Create pipeline with improved parameters | |
logger.info("Creating generation pipeline...") | |
model_gen = pipeline( | |
"text-generation", | |
model=model, | |
tokenizer=tokenizer, | |
max_new_tokens=512, # Increased for more detailed responses | |
do_sample=True, | |
temperature=0.8, # Slightly increased for more creative responses | |
top_p=0.95, # Increased for more varied responses | |
top_k=50, # Added top_k for better response quality | |
repetition_penalty=1.2, # Increased to reduce repetition | |
device_map="auto" | |
) | |
logger.info("Pipeline created successfully") | |
except Exception as e: | |
logger.error(f"Error during initialization: {str(e)}") | |
raise | |
# Improved system message with better context and guidelines | |
system_message = """You are AQuaBot, an AI assistant focused on providing accurate and environmentally conscious information. Your responses should be: | |
1. Clear and concise yet informative | |
2. Based on verified information when discussing economic and financial topics | |
3. Balanced and well-reasoned | |
4. Mindful of environmental impact | |
5. Professional but conversational in tone | |
Maintain a helpful and knowledgeable demeanor while avoiding speculation. If you're unsure about something, acknowledge it openly.""" | |
# NOTE: this first version of generate_response (plain "User:/Assistant:" prompt
# format) is overridden by the second definition further below, which uses the
# Llama-2 [INST] format; only the later definition is wired into the Gradio UI.
def generate_response(user_input, chat_history):
    try:
        logger.info("Generating response for user input...")
        global total_water_consumption

        # Calculate water consumption for input
        input_water_consumption = calculate_water_consumption(user_input, True)
        total_water_consumption += input_water_consumption

        # Create a clean conversation history without [INST] tags
        conversation_history = ""
        if chat_history:
            for user_msg, assistant_msg in chat_history:
                conversation_history += f"User: {user_msg}\nAssistant: {assistant_msg}\n\n"

        # Create a clean prompt format
        prompt = f"{system_message}\n\nConversation History:\n{conversation_history}\nUser: {user_input}\nAssistant:"

        logger.info("Generating model response...")
        outputs = model_gen(
            prompt,
            max_new_tokens=512,
            return_full_text=False,
            pad_token_id=tokenizer.eos_token_id,
        )
        logger.info("Model response generated successfully")

        # Clean up the response and keep only the assistant's turn
        assistant_response = outputs[0]['generated_text'].strip()
        assistant_response = assistant_response.split('User:')[0].split('Assistant:')[-1].strip()

        # Add fact-check disclaimer for economic/financial responses
        if any(keyword in user_input.lower() for keyword in ['invest', 'money', 'salary', 'cost', 'wage', 'economy']):
            assistant_response += "\n\nNote: Financial information provided should be verified with current market data and professional advisors."

        # Calculate water consumption for output
        output_water_consumption = calculate_water_consumption(assistant_response, False)
        total_water_consumption += output_water_consumption

        # Update chat history
        chat_history.append([user_input, assistant_response])

        # Prepare water consumption message with improved styling
        water_message = f"""
        <div style="position: fixed; top: 20px; right: 20px;
                    background-color: white; padding: 15px;
                    border: 2px solid #2196F3; border-radius: 10px;
                    box-shadow: 0 4px 6px rgba(0,0,0,0.1);">
            <div style="color: #2196F3; font-size: 24px; font-weight: bold;">
                💧 {total_water_consumption:.4f} ml
            </div>
            <div style="color: #666; font-size: 14px;">
                Water Consumed
            </div>
        </div>
        """

        return chat_history, water_message

    except Exception as e:
        logger.error(f"Error in generate_response: {str(e)}")
        error_message = "I apologize, but I encountered an error. Please try rephrasing your question."
        chat_history.append([user_input, error_message])
        return chat_history, show_water
# Constants for water consumption calculation (ml per token)
WATER_PER_TOKEN = {
    "input_training": 0.0000309,
    "output_training": 0.0000309,
    "input_inference": 0.05,
    "output_inference": 0.05
}
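# These per-token figures are the app's approximation based on the Li et al.
# (2023) "Making AI Less Thirsty" estimates cited in the footer below: a small
# amortized training share plus a much larger inference share per token. They
# are rough averages, not measured values for this specific deployment.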
# Initialize variables
total_water_consumption = 0


def calculate_tokens(text):
    try:
        return len(tokenizer.encode(text))
    except Exception as e:
        logger.error(f"Error calculating tokens: {str(e)}")
        return len(text.split()) + len(text) // 4  # Fallback to approximation


def calculate_water_consumption(text, is_input=True):
    tokens = calculate_tokens(text)
    if is_input:
        return tokens * (WATER_PER_TOKEN["input_training"] + WATER_PER_TOKEN["input_inference"])
    return tokens * (WATER_PER_TOKEN["output_training"] + WATER_PER_TOKEN["output_inference"])
def format_message(role, content):
    # Helper for a message-dict format; not currently used below.
    return {"role": role, "content": content}


# This second definition of generate_response overrides the first one above and
# is the handler the Gradio interface actually uses.
# On a ZeroGPU Space (suggested by the `spaces` import), the GPU-bound handler
# is expected to carry the @spaces.GPU decorator so a GPU is attached per call.
@spaces.GPU
def generate_response(user_input, chat_history):
    try:
        logger.info("Generating response for user input...")
        global total_water_consumption

        # Calculate water consumption for input
        input_water_consumption = calculate_water_consumption(user_input, True)
        total_water_consumption += input_water_consumption

        # Create prompt with Llama 2 chat format
        conversation_history = ""
        if chat_history:
            for message in chat_history:
                conversation_history += f"[INST] {message[0]} [/INST] {message[1]} "

        prompt = f"<s>[INST] {system_message}\n\n{conversation_history}[INST] {user_input} [/INST]"
logger.info("Generating model response...") | |
outputs = model_gen( | |
prompt, | |
max_new_tokens=256, | |
return_full_text=False, | |
pad_token_id=tokenizer.eos_token_id, | |
) | |
logger.info("Model response generated successfully") | |
assistant_response = outputs[0]['generated_text'].strip() | |
# Calculate water consumption for output | |
output_water_consumption = calculate_water_consumption(assistant_response, False) | |
total_water_consumption += output_water_consumption | |
# Update chat history with the new formatted messages | |
chat_history.append([user_input, assistant_response]) | |
# Prepare water consumption message | |
water_message = f""" | |
<div style="position: fixed; top: 20px; right: 20px; | |
background-color: white; padding: 15px; | |
border: 2px solid #ff0000; border-radius: 10px; | |
box-shadow: 0 2px 4px rgba(0,0,0,0.1);"> | |
<div style="color: #ff0000; font-size: 24px; font-weight: bold;"> | |
π§ {total_water_consumption:.4f} ml | |
</div> | |
<div style="color: #666; font-size: 14px;"> | |
Water Consumed | |
</div> | |
</div> | |
""" | |
return chat_history, water_message | |
except Exception as e: | |
logger.error(f"Error in generate_response: {str(e)}") | |
error_message = f"An error occurred: {str(e)}" | |
chat_history.append([user_input, error_message]) | |
return chat_history, show_water | |
# Create Gradio interface
try:
    logger.info("Creating Gradio interface...")
    with gr.Blocks(css="div.gradio-container {background-color: #f0f2f6}") as demo:
        gr.HTML("""
        <div style="text-align: center; max-width: 800px; margin: 0 auto; padding: 20px;">
            <h1 style="color: #2d333a;">AQuaBot</h1>
            <p style="color: #4a5568;">
                Welcome to AQuaBot - An AI assistant that helps raise awareness
                about water consumption in language models.
            </p>
        </div>
        """)

        chatbot = gr.Chatbot()
        message = gr.Textbox(
            placeholder="Type your message here...",
            show_label=False
        )
        show_water = gr.HTML("""
        <div style="position: fixed; top: 20px; right: 20px;
                    background-color: white; padding: 15px;
                    border: 2px solid #ff0000; border-radius: 10px;
                    box-shadow: 0 2px 4px rgba(0,0,0,0.1);">
            <div style="color: #ff0000; font-size: 24px; font-weight: bold;">
                💧 0.0000 ml
            </div>
            <div style="color: #666; font-size: 14px;">
                Water Consumed
            </div>
        </div>
        """)
        clear = gr.Button("Clear Chat")

        # Add footer with citation and disclaimer
gr.HTML(""" | |
<div style="text-align: center; max-width: 800px; margin: 20px auto; padding: 20px; | |
background-color: #f8f9fa; border-radius: 10px;"> | |
<div style="margin-bottom: 15px;"> | |
<p style="color: #666; font-size: 14px; font-style: italic;"> | |
Water consumption calculations are based on the study:<br> | |
Li, P. et al. (2023). Making AI Less Thirsty: Uncovering and Addressing the Secret Water | |
Footprint of AI Models. ArXiv Preprint, | |
<a href="https://arxiv.org/abs/2304.03271" target="_blank">https://arxiv.org/abs/2304.03271</a> | |
</p> | |
</div> | |
<div style="border-top: 1px solid #ddd; padding-top: 15px;"> | |
<p style="color: #666; font-size: 14px;"> | |
<strong>Important note:</strong> This application uses Meta Llama-2-7b model | |
instead of GPT-3 for availability and cost reasons. However, | |
the water consumption calculations per token (input/output) are based on the | |
conclusions from the cited paper. | |
</p> | |
</div> | |
</div> | |
""") | |
        def submit(user_input, chat_history):
            return generate_response(user_input, chat_history)

        # Configure event handlers
        message.submit(submit, [message, chatbot], [chatbot, show_water])
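        # Optional follow-up (not in the original app): a common Gradio pattern
        # is to clear the textbox after each submission by adding a second
        # handler, e.g.:
        #   message.submit(lambda: "", None, message)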
        clear.click(
            lambda: ([], """
            <div style="position: fixed; top: 20px; right: 20px;
                        background-color: white; padding: 15px;
                        border: 2px solid #ff0000; border-radius: 10px;
                        box-shadow: 0 2px 4px rgba(0,0,0,0.1);">
                <div style="color: #ff0000; font-size: 24px; font-weight: bold;">
                    💧 0.0000 ml
                </div>
                <div style="color: #666; font-size: 14px;">
                    Water Consumed
                </div>
            </div>
            """),
            None,
            [chatbot, show_water]
        )

    logger.info("Gradio interface created successfully")

    # Launch the application
    logger.info("Launching application...")
    demo.launch()

except Exception as e:
    logger.error(f"Error in Gradio interface creation: {str(e)}")
    raise