baohuynhbk14 commited on
Commit
15d7f87
·
1 Parent(s): 4b99eed

Refactor predict function to remove message parameter and retrieve message from history

Browse files
Files changed (1) hide show
  1. app.py +10 -11
app.py CHANGED
@@ -185,8 +185,7 @@ model = AutoModel.from_pretrained(
185
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
186
 
187
  @spaces.GPU
188
- def predict(message,
189
- image_path,
190
  state,
191
  max_input_tiles=6,
192
  temperature=1.0,
@@ -218,6 +217,7 @@ def predict(message,
218
 
219
  history = state.get_history()
220
  logger.info(f"==== History ====\n{history}")
 
221
 
222
  logger.info(f"==== Lenght Pixel values ====\n{len(pixel_values)}")
223
 
@@ -282,15 +282,14 @@ def ai_bot(
282
  logger.info(f"==== User message ====\n{message}")
283
  logger.info(f"==== Image paths ====\n{all_image_paths}")
284
 
285
- response, _ = predict(message,
286
- all_image_paths[0] if len(all_image_paths) > 0 else None,
287
- state,
288
- max_input_tiles,
289
- temperature,
290
- max_new_tokens,
291
- top_p,
292
- repetition_penalty,
293
- do_sample)
294
 
295
  # response = "This is a test response"
296
  buffer = ""
 
185
  tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False)
186
 
187
  @spaces.GPU
188
+ def predict(image_path,
 
189
  state,
190
  max_input_tiles=6,
191
  temperature=1.0,
 
217
 
218
  history = state.get_history()
219
  logger.info(f"==== History ====\n{history}")
220
+ message = history[-1][0] if len(history) > 0 else ""
221
 
222
  logger.info(f"==== Lenght Pixel values ====\n{len(pixel_values)}")
223
 
 
282
  logger.info(f"==== User message ====\n{message}")
283
  logger.info(f"==== Image paths ====\n{all_image_paths}")
284
 
285
+ response, _ = predict(state,
286
+ all_image_paths[0] if len(all_image_paths) > 0 else None,
287
+ max_input_tiles,
288
+ temperature,
289
+ max_new_tokens,
290
+ top_p,
291
+ repetition_penalty,
292
+ do_sample)
 
293
 
294
  # response = "This is a test response"
295
  buffer = ""