amos1088 committed
Commit c9bee67 · Parent: de267c3

Files changed (3):
  1. Dockerfile +18 -15
  2. app.py +41 -57
  3. requirements.txt +1 -2
Dockerfile CHANGED
@@ -1,36 +1,39 @@
-FROM nvidia/cuda:12.3.2-cudnn9-devel-ubuntu22.04
+# Use official PyTorch with CUDA 12.1 (works with flash-attn)
+FROM pytorch/pytorch:2.3.0-cuda12.1-cudnn8-devel
 
 ENV DEBIAN_FRONTEND=noninteractive
 ENV OMP_NUM_THREADS=4
 ENV DISABLE_TRITON=1
 ENV ACCELERATE_USE_DEEPSPEED=0
+ENV TRANSFORMERS_VERBOSITY=info
+ENV PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+ENV FLASH_ATTENTION_FORCE=1
 
-# Install OS deps
+# Install system dependencies
 RUN apt-get update && apt-get install -y \
-    git wget curl python3 python3-pip python3-dev build-essential \
-    libssl-dev zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev \
+    git wget curl build-essential python3-dev \
     ffmpeg libsm6 libxext6 libgl1-mesa-glx \
     && rm -rf /var/lib/apt/lists/*
 
-# Set Python as default
-RUN update-alternatives --install /usr/bin/python python /usr/bin/python3 1
+# Upgrade pip first
 RUN pip install --upgrade pip
 
-# ---- 1. Install Torch first (needed for flash-attn) ----
-RUN pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cu121
-
-# ---- 2. Install requirements (without flash-attn) ----
+# Copy requirements (without flash-attn)
 WORKDIR /app
 COPY requirements.txt /app/requirements.txt
-RUN pip install --no-cache-dir -r requirements.txt
+RUN grep -v "flash-attn" requirements.txt > requirements-clean.txt
+
+# Install all Python deps except flash-attn
+RUN pip install --no-cache-dir -r requirements-clean.txt
 
-# ---- 3. Install flash-attn separately ----
-RUN pip install --no-build-isolation flash-attn
+# Install flash-attn last to ensure Torch is ready
+RUN pip install --no-build-isolation flash-attn==2.8.2
 
-# Copy your code
+# Copy application
 COPY . /app
 
+# Expose Gradio
 EXPOSE 7860
 
-# Launch Gradio app
+# Default command to launch your app
 CMD ["python", "app.py"]
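
Note on install order: flash-attn compiles against whatever torch it finds at build time, which is why the rewritten Dockerfile takes torch from the pytorch/pytorch base image, filters flash-attn out of requirements.txt, and only then installs it with --no-build-isolation. A minimal build-time sanity check under those assumptions (hypothetical script, not part of this commit):

    # check_flash_attn.py -- hypothetical sanity check, not part of this commit.
    # Importing flash_attn fails immediately if the wheel was built against a
    # different torch ABI, so running this during `docker build` fails fast
    # instead of at inference time.
    import torch
    import flash_attn

    print("torch:", torch.__version__, "CUDA:", torch.version.cuda)
    print("flash-attn:", flash_attn.__version__)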
app.py CHANGED
@@ -275,10 +275,7 @@ def get_data_status():
     """Get data download status"""
     return f"{data_download_status['message']}"
 
-
 def run_inference(query, document_title, document_content, checkpoint="latest"):
-    import torch
-
     global current_model, current_tokenizer
 
     # Load the model if not already loaded
@@ -295,49 +292,32 @@ def run_inference(query, document_title, document_content, checkpoint="latest"):
     else:
         load_model_and_tokenizer(checkpoint)
 
-    # Prepare prompt exactly like training
-    prompt = f"""you would get a query and document's title and content and return Relevant/Irrelevant.
-
-Query:
-{query}
-
-Document:
-title: {document_title}
-content: {document_content}
-"""
-
-    # Helper function to score log-probability
-    def score_response(model, tokenizer, prompt, response):
-        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-        labels = tokenizer(response, return_tensors="pt").to(model.device)
-
-        # Concatenate prompt and response (excluding last token of response for shifting)
-        input_ids = torch.cat([inputs.input_ids, labels.input_ids[:, :-1]], dim=1)
-        attention_mask = torch.ones_like(input_ids)
-
-        with torch.no_grad():
-            outputs = model(input_ids=input_ids, attention_mask=attention_mask, use_cache=False)
-            log_probs = torch.log_softmax(outputs.logits, dim=-1)
-
-        # Compute average log-prob for the response tokens
-        target_ids = labels.input_ids
-        seq_logprob = 0
-        count = 0
-        for i in range(target_ids.shape[1]):
-            token_id = target_ids[0, i].item()
-            if token_id == tokenizer.pad_token_id:
-                continue
-            seq_logprob += log_probs[0, inputs.input_ids.shape[1] + i - 1, token_id].item()
-            count += 1
-
-        return seq_logprob / max(count, 1)
-
-    # Compute log-prob for both options
-    score_relevant = score_response(current_model, current_tokenizer, prompt, "Relevant")
-    score_irrelevant = score_response(current_model, current_tokenizer, prompt, "Irrelevant")
-
-    # Return the higher-probability label
-    return "Relevant" if score_relevant > score_irrelevant else "Irrelevant"
+    # Prepare prompt like training
+    prompt = format_prompt_for_inference(query, document_title, document_content)
+
+    # Tokenize
+    inputs = current_tokenizer(
+        prompt, return_tensors="pt", truncation=True, max_length=512
+    )
+    inputs = {k: v.to(current_model.device) for k, v in inputs.items()}
+
+    # Generate single label
+    with torch.no_grad():
+        outputs = current_model.generate(
+            **inputs,
+            max_new_tokens=5,
+            temperature=0.0,
+            do_sample=False,
+            pad_token_id=current_tokenizer.eos_token_id,
+            use_cache=False
+        )
+
+    response = current_tokenizer.decode(outputs[0], skip_special_tokens=True)
+    response = response[len(prompt):].strip().lower()
+
+    if response.startswith("irrelevant"):
+        return "Irrelevant"
+    return "Relevant"
 
 
 def list_checkpoints():
@@ -432,19 +412,19 @@ with gr.Blocks(title="Phi-3 DPO Training on BEIR") as demo:
 
     import time
 
-    def batch_inference(csv_file, checkpoint="latest", batch_size=16):
+
+    def batch_inference(csv_file, checkpoint="latest", batch_size=64):
         import pandas as pd
 
         if csv_file is None:
             raise ValueError("No CSV file uploaded.")
 
-        # Gradio File can be str (path) or dict
+        # Gradio File can be path (str) or tempfile object
         csv_path = csv_file if isinstance(csv_file, str) else getattr(csv_file, "name", None)
         if csv_path is None:
             raise ValueError("Invalid file input from Gradio.")
 
         df = pd.read_csv(csv_path)
-
         if "prompt" not in df.columns:
             raise ValueError("CSV must have a 'prompt' column")
 
@@ -468,7 +448,7 @@ with gr.Blocks(title="Phi-3 DPO Training on BEIR") as demo:
         correct = 0
         total = len(prompts)
 
-        # Create temp output CSV
+        # Temp output path
         output_path = "/tmp/batch_inference_results.csv"
 
         for i in range(0, total, batch_size):
@@ -492,40 +472,44 @@ with gr.Blocks(title="Phi-3 DPO Training on BEIR") as demo:
             )
 
             batch_decoded = current_tokenizer.batch_decode(outputs, skip_special_tokens=True)
+
             for prompt, decoded in zip(batch_prompts, batch_decoded):
-                response = decoded[len(prompt):].strip()
-                if "relevant" in response.lower():
-                    pred = "Relevant" if "irrelevant" not in response.lower() else "Irrelevant"
+                response = decoded[len(prompt):].strip().lower()
+                if response.startswith("irrelevant"):
+                    pred = "Irrelevant"
+                elif response.startswith("relevant"):
+                    pred = "Relevant"
                 else:
-                    pred = response
+                    pred = decoded.strip()
                 predictions.append(pred)
 
-            # Optional: compute running accuracy
+            # Accuracy calculation
             if "chosen" in df.columns:
                 for j, pred in enumerate(predictions[-len(batch_prompts):]):
                     idx = i + j
                     if str(df["chosen"].iloc[idx]).strip().lower() == pred.lower():
                         correct += 1
 
-            # Update progress every batch
-            progress = (i + batch_size) / total * 100
+            # Save partial results for streaming
             df_partial = df.copy()
            df_partial.loc[:len(predictions) - 1, "prediction"] = predictions
             df_partial.to_csv(output_path, index=False)
 
+            # Progress & accuracy stats
+            progress = min(i + batch_size, total) / total * 100
             stats = f"Processed {min(i + batch_size, total)}/{total} rows ({progress:.1f}%)"
             if "chosen" in df.columns:
-                stats += f"\nCurrent Accuracy: {correct / max(1, len(predictions)) * 100:.2f}%"
+                stats += f"\nCurrent Accuracy: {correct / len(predictions) * 100:.2f}%"
 
-            # Yield progress to Gradio
+            # Stream update to Gradio
             yield output_path, stats
 
         # Final stats
         final_stats = f"✅ Processed {total} rows"
         if "chosen" in df.columns:
             final_stats += f"\nFinal Accuracy: {correct / total * 100:.2f}%"
-        yield output_path, final_stats
 
+        yield output_path, final_stats
 
     csv_infer_btn = gr.Button("Run Batch Inference")
     csv_infer_btn.click(
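
Two details worth noting in the new app.py logic. First, the label readout checks response.startswith("irrelevant") before "relevant", because "irrelevant" contains "relevant" as a substring — exactly what made the old substring test fragile. Second, batched generate() with a decoder-only model such as Phi-3 normally requires left padding, or generated tokens end up separated from each prompt by pad tokens; a setup sketch under that assumption (variable names mirror the globals above; the snippet is illustrative, not part of the commit):

    # Hypothetical setup sketch, not part of this commit: decoder-only models
    # should pad on the left for batched generation so that new tokens are
    # appended directly after each real prompt.
    current_tokenizer.padding_side = "left"
    if current_tokenizer.pad_token is None:
        current_tokenizer.pad_token = current_tokenizer.eos_token

    batch = current_tokenizer(
        batch_prompts, return_tensors="pt",
        padding=True, truncation=True, max_length=512,
    )
    batch = {k: v.to(current_model.device) for k, v in batch.items()}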
requirements.txt CHANGED
@@ -6,8 +6,7 @@ accelerate>=0.25.0
 bitsandbytes>=0.41.0
 datasets
 pandas
-torch>=2.0.0
 scipy
 beir
 scikit-learn
-tqdm
+tqdm
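
Dropping torch>=2.0.0 is consistent with the Dockerfile change: the pytorch/pytorch:2.3.0-cuda12.1 base image already ships a CUDA-enabled torch, and letting pip resolve torch from requirements.txt could replace it with a mismatched wheel. A quick check of which build actually ends up in the image (hypothetical, not part of this commit):

    # Hypothetical check, not part of this commit: the versions printed should
    # match the base image tag (torch 2.3.0 built for CUDA 12.1).
    import torch

    print(torch.__version__)          # expect 2.3.0
    print(torch.version.cuda)         # expect 12.1
    print(torch.cuda.is_available())  # True only when a GPU is visible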