CamiloVega committed on
Commit 23734f7 · verified · 1 Parent(s): 1f5453e

Update app.py

Files changed (1)
  1. app.py +37 -46
app.py CHANGED
@@ -25,7 +25,7 @@ model_name = "meta-llama/Llama-2-7b-hf"

 try:
     logger.info("Starting model initialization...")
-
+
     # Check CUDA availability
     device = "cuda" if torch.cuda.is_available() else "cpu"
     logger.info(f"Using device: {device}")
@@ -45,28 +45,31 @@ try:
     tokenizer.pad_token = tokenizer.eos_token
     logger.info("Tokenizer loaded successfully")

-    # Load model with basic configuration
+    # Load model with optimized configuration
     logger.info("Loading model...")
     model = AutoModelForCausalLM.from_pretrained(
         model_name,
         torch_dtype=torch.float16 if device == "cuda" else torch.float32,
         trust_remote_code=True,
         token=hf_token,
-        device_map="auto"
+        device_map="auto",
+        max_memory={0: "12GiB"} if device == "cuda" else None,
+        load_in_8bit=True if device == "cuda" else False
     )
     logger.info("Model loaded successfully")

-    # Create pipeline
+    # Create pipeline with improved parameters
     logger.info("Creating generation pipeline...")
     model_gen = pipeline(
         "text-generation",
         model=model,
         tokenizer=tokenizer,
-        max_new_tokens=256,
+        max_new_tokens=512,  # Increased for more detailed responses
         do_sample=True,
-        temperature=0.7,
-        top_p=0.9,
-        repetition_penalty=1.1,
+        temperature=0.8,  # Slightly increased for more creative responses
+        top_p=0.95,  # Increased for more varied responses
+        top_k=50,  # Added top_k for better response quality
+        repetition_penalty=1.2,  # Increased to reduce repetition
         device_map="auto"
     )
     logger.info("Pipeline created successfully")
@@ -75,9 +78,15 @@ except Exception as e:
     logger.error(f"Error during initialization: {str(e)}")
     raise

-# Configure system message
+# Improved system message with better context and guidelines
+system_message = """You are AQuaBot, an AI assistant focused on providing accurate and environmentally conscious information. Your responses should be:
+1. Clear and concise yet informative
+2. Based on verified information when discussing economic and financial topics
+3. Balanced and well-reasoned
+4. Mindful of environmental impact
+5. Professional but conversational in tone

-system_message = """You are a helpful AI assistant called AQuaBot. You provide direct, clear, and detailed answers to questions while being aware of environmental impact. Keep your responses natural and informative, but concise. Always provide context and explanations with your answers. Respond directly to questions without using any special tags or markers."""
+Maintain a helpful and knowledgeable demeanor while avoiding speculation. If you're unsure about something, acknowledge it openly."""

 @spaces.GPU(duration=60)
 @torch.inference_mode()
@@ -90,64 +99,46 @@ def generate_response(user_input, chat_history):
         input_water_consumption = calculate_water_consumption(user_input, True)
         total_water_consumption += input_water_consumption

-        # Create prompt with Llama 2 chat format
+        # Create a clean conversation history without [INST] tags
         conversation_history = ""
         if chat_history:
-            for message in chat_history:
-                # Remove any [INST] tags from the history
-                user_msg = message[0].replace("[INST]", "").replace("[/INST]", "").strip()
-                assistant_msg = message[1].replace("[INST]", "").replace("[/INST]", "").strip()
-                conversation_history += f"[INST] {user_msg} [/INST] {assistant_msg} "
-
-        prompt = f"<s>[INST] {system_message}\n\n{conversation_history}[INST] {user_input} [/INST]"
+            for user_msg, assistant_msg in chat_history:
+                conversation_history += f"User: {user_msg}\nAssistant: {assistant_msg}\n\n"
+
+        # Create a clean prompt format
+        prompt = f"{system_message}\n\nConversation History:\n{conversation_history}\nUser: {user_input}\nAssistant:"

         logger.info("Generating model response...")
         outputs = model_gen(
             prompt,
-            max_new_tokens=256,
+            max_new_tokens=512,
             return_full_text=False,
             pad_token_id=tokenizer.eos_token_id,
-            do_sample=True,
-            temperature=0.7,
-            top_p=0.9,
-            repetition_penalty=1.1
         )
         logger.info("Model response generated successfully")

-        # Clean up the response by removing any [INST] tags and trimming
+        # Clean up response and remove any remaining [INST] tags
         assistant_response = outputs[0]['generated_text'].strip()
-        assistant_response = assistant_response.replace("[INST]", "").replace("[/INST]", "").strip()
+        assistant_response = assistant_response.split('User:')[0].split('Assistant:')[-1].strip()

-        # If the response is too short, try to generate a more detailed one
-        if len(assistant_response.split()) < 10:
-            prompt += "\nPlease provide a more detailed answer with context and explanation."
-            outputs = model_gen(
-                prompt,
-                max_new_tokens=256,
-                return_full_text=False,
-                pad_token_id=tokenizer.eos_token_id,
-                do_sample=True,
-                temperature=0.7,
-                top_p=0.9,
-                repetition_penalty=1.1
-            )
-            assistant_response = outputs[0]['generated_text'].strip()
-            assistant_response = assistant_response.replace("[INST]", "").replace("[/INST]", "").strip()
+        # Add fact-check disclaimer for economic/financial responses
+        if any(keyword in user_input.lower() for keyword in ['invest', 'money', 'salary', 'cost', 'wage', 'economy']):
+            assistant_response += "\n\nNote: Financial information provided should be verified with current market data and professional advisors."

         # Calculate water consumption for output
         output_water_consumption = calculate_water_consumption(assistant_response, False)
         total_water_consumption += output_water_consumption

-        # Update chat history with the cleaned messages
+        # Update chat history
         chat_history.append([user_input, assistant_response])

-        # Prepare water consumption message
+        # Prepare water consumption message with improved styling
         water_message = f"""
         <div style="position: fixed; top: 20px; right: 20px;
                     background-color: white; padding: 15px;
-                    border: 2px solid #ff0000; border-radius: 10px;
-                    box-shadow: 0 2px 4px rgba(0,0,0,0.1);">
-            <div style="color: #ff0000; font-size: 24px; font-weight: bold;">
+                    border: 2px solid #2196F3; border-radius: 10px;
+                    box-shadow: 0 4px 6px rgba(0,0,0,0.1);">
+            <div style="color: #2196F3; font-size: 24px; font-weight: bold;">
                 💧 {total_water_consumption:.4f} ml
             </div>
             <div style="color: #666; font-size: 14px;">
@@ -160,7 +151,7 @@ def generate_response(user_input, chat_history):

     except Exception as e:
         logger.error(f"Error in generate_response: {str(e)}")
-        error_message = f"An error occurred: {str(e)}"
+        error_message = f"I apologize, but I encountered an error. Please try rephrasing your question."
         chat_history.append([user_input, error_message])
         return chat_history, show_water
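A note on the new loading arguments: `load_in_8bit=True` and `max_memory` only take effect when the bitsandbytes and accelerate packages are installed, and recent transformers releases expect the 8-bit flag to be wrapped in a `BitsAndBytesConfig` rather than passed directly to `from_pretrained`. A minimal sketch of that newer form, assuming a single CUDA device and the same `hf_token` already read in app.py:

    # Sketch only: the commit's 8-bit load expressed with BitsAndBytesConfig
    # (newer transformers API). Assumes bitsandbytes and accelerate are installed
    # and hf_token is the token defined earlier in app.py.
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    quant_config = BitsAndBytesConfig(load_in_8bit=True)  # replaces the bare load_in_8bit flag
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",
        quantization_config=quant_config,
        device_map="auto",
        max_memory={0: "12GiB"},  # same per-GPU cap used in the commit
        token=hf_token,
    )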
 
 
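Because the prompt now uses a plain "User:/Assistant:" transcript instead of Llama 2 [INST] tags, the cleanup step truncates the generation at the first hallucinated "User:" turn and drops any echoed "Assistant:" label (with return_full_text=False the prompt itself is already excluded from the output). A small illustration of that expression, using an invented generation string:

    # Sketch only: behaviour of the commit's cleanup expression on a sample string.
    # The generation below is made up for illustration, not produced by the app.
    generated = "Mostly from data-center cooling towers.\nUser: thanks!"
    cleaned = generated.split('User:')[0].split('Assistant:')[-1].strip()
    print(cleaned)  # -> "Mostly from data-center cooling towers."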