- Dockerfile +18 -15
- app.py +41 -57
- requirements.txt +1 -2
Dockerfile
CHANGED
@@ -1,36 +1,39 @@
-FROM …
+# Use official PyTorch with CUDA 12.1 (works with flash-attn)
+FROM pytorch/pytorch:2.3.0-cuda12.1-cudnn8-devel
 
 ENV DEBIAN_FRONTEND=noninteractive
 ENV OMP_NUM_THREADS=4
 ENV DISABLE_TRITON=1
 ENV ACCELERATE_USE_DEEPSPEED=0
+ENV TRANSFORMERS_VERBOSITY=info
+ENV PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+ENV FLASH_ATTENTION_FORCE=1
 
-# Install
+# Install system dependencies
 RUN apt-get update && apt-get install -y \
-    git wget curl \
-    libssl-dev zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev \
+    git wget curl build-essential python3-dev \
     ffmpeg libsm6 libxext6 libgl1-mesa-glx \
     && rm -rf /var/lib/apt/lists/*
 
-#
-RUN update-alternatives --install /usr/bin/python python /usr/bin/python3 1
+# Upgrade pip first
 RUN pip install --upgrade pip
 
-#
-RUN pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cu121
-
-# ---- 2. Install requirements (without flash-attn) ----
+# Copy requirements (without flash-attn)
 WORKDIR /app
 COPY requirements.txt /app/requirements.txt
-RUN pip install --no-cache-dir -r requirements.txt
+RUN grep -v "flash-attn" requirements.txt > requirements-clean.txt
+
+# Install all Python deps except flash-attn
+RUN pip install --no-cache-dir -r requirements-clean.txt
 
-#
-RUN pip install --no-build-isolation flash-attn
+# Install flash-attn last to ensure Torch is ready
+RUN pip install --no-build-isolation flash-attn==2.8.2
 
-# Copy
+# Copy application
 COPY . /app
 
+# Expose Gradio
 EXPOSE 7860
 
-#
+# Default command to launch your app
 CMD ["python", "app.py"]
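The ordering here matters: flash-attn compiles (or resolves a prebuilt wheel) against whatever Torch/CUDA it finds at install time, so the commit installs it last, with --no-build-isolation, on top of a base image that already ships Torch 2.3.0 with CUDA 12.1. A minimal sanity check, not part of this commit, could be wired in after the flash-attn step (e.g. via RUN python -c "...") to fail the build early if the wheel and the base image disagree:

# Sketch: verify the torch / flash-attn pairing inside the image.
import torch
import flash_attn  # import fails if the wheel doesn't match the Torch/CUDA ABI

print("torch:", torch.__version__, "| cuda:", torch.version.cuda)
print("flash-attn:", flash_attn.__version__)
assert torch.version.cuda is not None, "expected a CUDA-enabled torch build"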
app.py
CHANGED
@@ -275,10 +275,7 @@ def get_data_status():
     """Get data download status"""
     return f"{data_download_status['message']}"
 
-
 def run_inference(query, document_title, document_content, checkpoint="latest"):
-    import torch
-
     global current_model, current_tokenizer
 
     # Load the model if not already loaded
@@ -295,49 +292,32 @@ def run_inference(query, document_title, document_content, checkpoint="latest"):
     else:
         load_model_and_tokenizer(checkpoint)
 
-    # Prepare prompt
-    prompt = f"""
-
-    Query:
-    {query}
-
-    Document:
-    title: {document_title}
-    content: {document_content}
-    """
-
-    # Helper function to score log-probability
-    def score_response(model, tokenizer, prompt, response):
-        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-        labels = tokenizer(response, return_tensors="pt").to(model.device)
-        …
-            seq_logprob += log_probs[0, inputs.input_ids.shape[1] + i - 1, token_id].item()
-            count += 1
-
-        return seq_logprob / max(count, 1)
-
-    score_irrelevant = score_response(current_model, current_tokenizer, prompt, "Irrelevant")
-    …
+    # Prepare prompt like training
+    prompt = format_prompt_for_inference(query, document_title, document_content)
+
+    # Tokenize
+    inputs = current_tokenizer(
+        prompt, return_tensors="pt", truncation=True, max_length=512
+    )
+    inputs = {k: v.to(current_model.device) for k, v in inputs.items()}
+
+    # Generate single label
+    with torch.no_grad():
+        outputs = current_model.generate(
+            **inputs,
+            max_new_tokens=5,
+            temperature=0.0,
+            do_sample=False,
+            pad_token_id=current_tokenizer.eos_token_id,
+            use_cache=False
+        )
+
+    response = current_tokenizer.decode(outputs[0], skip_special_tokens=True)
+    response = response[len(prompt):].strip().lower()
+
+    if response.startswith("irrelevant"):
+        return "Irrelevant"
+    return "Relevant"
 
 
 def list_checkpoints():
@@ -432,19 +412,19 @@ with gr.Blocks(title="Phi-3 DPO Training on BEIR") as demo:
 
     import time
 
-    def batch_inference(…):
+
+    def batch_inference(csv_file, checkpoint="latest", batch_size=64):
         import pandas as pd
 
         if csv_file is None:
             raise ValueError("No CSV file uploaded.")
 
-        # Gradio File can be
+        # Gradio File can be path (str) or tempfile object
         csv_path = csv_file if isinstance(csv_file, str) else getattr(csv_file, "name", None)
         if csv_path is None:
             raise ValueError("Invalid file input from Gradio.")
 
         df = pd.read_csv(csv_path)
-
         if "prompt" not in df.columns:
             raise ValueError("CSV must have a 'prompt' column")
 
@@ -468,7 +448,7 @@ with gr.Blocks(title="Phi-3 DPO Training on BEIR") as demo:
         correct = 0
         total = len(prompts)
 
-        #
+        # Temp output path
        output_path = "/tmp/batch_inference_results.csv"
 
        for i in range(0, total, batch_size):
@@ -492,40 +472,44 @@ with gr.Blocks(title="Phi-3 DPO Training on BEIR") as demo:
             )
 
             batch_decoded = current_tokenizer.batch_decode(outputs, skip_special_tokens=True)
+
             for prompt, decoded in zip(batch_prompts, batch_decoded):
-                response = decoded[len(prompt):].strip()
-                if …
-                    pred = "…
+                response = decoded[len(prompt):].strip().lower()
+                if response.startswith("irrelevant"):
+                    pred = "Irrelevant"
+                elif response.startswith("relevant"):
+                    pred = "Relevant"
                 else:
-                    pred = …
+                    pred = decoded.strip()
                 predictions.append(pred)
 
-            #
+            # Accuracy calculation
             if "chosen" in df.columns:
                 for j, pred in enumerate(predictions[-len(batch_prompts):]):
                     idx = i + j
                     if str(df["chosen"].iloc[idx]).strip().lower() == pred.lower():
                         correct += 1
 
-            #
-            progress = (i + batch_size) / total * 100
+            # Save partial results for streaming
             df_partial = df.copy()
             df_partial.loc[:len(predictions) - 1, "prediction"] = predictions
             df_partial.to_csv(output_path, index=False)
 
+            # Progress & accuracy stats
+            progress = min(i + batch_size, total) / total * 100
             stats = f"Processed {min(i + batch_size, total)}/{total} rows ({progress:.1f}%)"
             if "chosen" in df.columns:
-                stats += f"\nCurrent Accuracy: {correct / …
+                stats += f"\nCurrent Accuracy: {correct / len(predictions) * 100:.2f}%"
 
-            #
+            # Stream update to Gradio
             yield output_path, stats
 
         # Final stats
         final_stats = f"✅ Processed {total} rows"
         if "chosen" in df.columns:
             final_stats += f"\nFinal Accuracy: {correct / total * 100:.2f}%"
-            yield output_path, final_stats
+
+        yield output_path, final_stats
 
     csv_infer_btn = gr.Button("Run Batch Inference")
     csv_infer_btn.click(
requirements.txt
CHANGED
@@ -6,8 +6,7 @@ accelerate>=0.25.0
 bitsandbytes>=0.41.0
 datasets
 pandas
-torch>=2.0.0
 scipy
 beir
 scikit-learn
-tqdm
+tqdm
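Dropping torch>=2.0.0 is the counterpart of the Dockerfile change: the base image already provides Torch 2.3.0 with CUDA 12.1, and letting pip resolve torch from requirements.txt risks replacing it with a mismatched wheel. If the old floor is worth keeping as a documented expectation, a runtime guard is one option (a sketch; it assumes the packaging module is available, which pip-managed environments normally provide):

# Sketch: assert the preinstalled torch still satisfies the removed pin.
from packaging.version import Version
import torch

installed = Version(torch.__version__.split("+")[0])  # "2.3.0+cu121" -> "2.3.0"
assert installed >= Version("2.0.0"), f"torch {installed} is below the 2.0.0 floor"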