# app.py
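"""Gradio Space app that fine-tunes LLaMA 2 (7B) with LoRA on healthcare
fraud-detection training pairs uploaded as PDFs or JSON."""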
import gradio as gr
import datasets
import torch
import json
import os
import pdfplumber
import huggingface_hub
import bitsandbytes  # noqa: F401 -- required for 8-bit loading; importing it fails fast if missing
import sentencepiece  # noqa: F401 -- required by LlamaTokenizer; importing it fails fast if missing
from transformers import LlamaForCausalLM, LlamaTokenizer, TrainingArguments, Trainer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from accelerate import Accelerator  # noqa: F401 -- accelerate must be installed for device_map="auto"
# Debug: confirm the 'LLama' secret is visible to the process without dumping
# os.environ, which would leak every secret into the Space logs
print("LLama secret present:", "LLama" in os.environ)
# Retrieve the token from Hugging Face Space secrets (Settings > Secrets, key 'LLama')
LLama = os.getenv("LLama")
if not LLama:
    raise ValueError("LLama token not found in environment variables. Please set it in Hugging Face Space secrets under 'Settings' > 'Secrets' as 'LLama'.")
# Debug: Print the token to verify it's being read (remove this in production)
print(f"Retrieved LLama token: {LLama[:5]}... (first 5 chars for security)")
# Authenticate with Hugging Face; the LLaMA 2 weights are gated, so the token must belong to an approved account
huggingface_hub.login(token=LLama)
# Model setup
MODEL_ID = "meta-llama/Llama-2-7b-hf"
tokenizer = LlamaTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
# Load the base model in 8-bit with the default attention implementation (no Flash Attention).
# Recent transformers releases expect quantization options via BitsAndBytesConfig
# rather than the deprecated load_in_8bit=True keyword.
model = LlamaForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)
# Add a padding token if the tokenizer lacks one and resize the embeddings to match
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    model.resize_token_embeddings(len(tokenizer))
# Prepare model for LoRA training
model = prepare_model_for_kbit_training(model)
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]
)
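# Wrap the quantized base model with LoRA adapters on the attention projections;
# only the adapter weights (a small fraction of the 7B parameters) remain trainable,
# which print_trainable_parameters() below confirms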
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
# Function to process uploaded files and train
def train_ui(files):
    try:
        # Collect raw text from PDFs and/or build a dataset from an uploaded JSON file
        raw_text = ""
        dataset = None
        for file in files:
            if file.name.endswith(".pdf"):
                with pdfplumber.open(file.name) as pdf:
                    for page in pdf.pages:
                        raw_text += page.extract_text() or ""
            elif file.name.endswith(".json"):
                with open(file.name, "r", encoding="utf-8") as f:
                    raw_data = json.load(f)
                training_data = raw_data.get("training_pairs", raw_data)
                with open("temp_fraud_data.json", "w", encoding="utf-8") as f:
                    json.dump({"training_pairs": training_data}, f)
                dataset = datasets.load_dataset("json", data_files="temp_fraud_data.json")
        if not raw_text and dataset is None:
            return "Error: No valid PDF or JSON data found."
        # Create training pairs from PDF text only when no JSON dataset was supplied
        if raw_text and dataset is None:
            def create_training_pairs(text):
                pairs = []
                if "Haloperidol" in text and "daily" in text.lower():
                    pairs.append({
                        "input": "Patient received Haloperidol daily. Is this overmedication?",
                        "output": "Yes, daily Haloperidol use without documented severe psychosis or failed alternatives may indicate overmedication, violating CMS guidelines."
                    })
                if "Lorazepam" in text and "frequent" in text.lower():
                    pairs.append({
                        "input": "Care logs show frequent Lorazepam use with a 90-day supply. Is this suspicious?",
                        "output": "Yes, frequent use with a large supply suggests potential overuse or mismanagement, a fraud indicator."
                    })
                return pairs
            training_data = create_training_pairs(raw_text)
            if not training_data:
                return "Error: No recognizable medication patterns found in the uploaded PDFs."
            with open("temp_fraud_data.json", "w", encoding="utf-8") as f:
                json.dump({"training_pairs": training_data}, f)
            dataset = datasets.load_dataset("json", data_files="temp_fraud_data.json")
        # Tokenize each pair in the LLaMA-2 instruction format; labels mirror input_ids
        def tokenize_data(example):
            formatted_text = f"<s>[INST] {example['input']} [/INST] {example['output']}</s>"
            inputs = tokenizer(formatted_text, padding="max_length", truncation=True, max_length=4096, return_tensors="pt")
            inputs["labels"] = inputs["input_ids"].clone()
            return {k: v.squeeze(0) for k, v in inputs.items()}
        # batched=False because tokenize_data handles a single example at a time
        tokenized_dataset = dataset["train"].map(tokenize_data, batched=False, remove_columns=dataset["train"].column_names)
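        # Effective batch size below: 4 per device x 8 accumulation steps = 32 examples per update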
        # Training setup
        training_args = TrainingArguments(
            output_dir="./fine_tuned_llama_healthcare",
            per_device_train_batch_size=4,
            gradient_accumulation_steps=8,
            eval_strategy="no",
            save_strategy="epoch",
            save_total_limit=2,
            num_train_epochs=5,
            learning_rate=2e-5,
            weight_decay=0.01,
            logging_dir="./logs",
            logging_steps=10,
            bf16=True,
            gradient_checkpointing=True,
            optim="adamw_torch",
            warmup_steps=100,
        )
        # datasets.map() stores columns as plain Python lists, so rebuild tensors here
        # (torch.stack would fail because the features are no longer tensors)
        def custom_data_collator(features):
            return {
                "input_ids": torch.tensor([f["input_ids"] for f in features]),
                "attention_mask": torch.tensor([f["attention_mask"] for f in features]),
                "labels": torch.tensor([f["labels"] for f in features]),
            }
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_dataset,
            data_collator=custom_data_collator,
        )
        trainer.train()
        model.save_pretrained("./fine_tuned_llama_healthcare")
        tokenizer.save_pretrained("./fine_tuned_llama_healthcare")
        return "Training completed! Model saved to ./fine_tuned_llama_healthcare"
    except Exception as e:
        return f"Error: {str(e)}. Please check file format, dependencies, or the LLama token."
# Gradio UI
with gr.Blocks(title="Healthcare Fraud Detection Fine-Tuning") as demo:
    gr.Markdown("# Fine-Tune LLaMA 2 for Healthcare Fraud Analysis")
    gr.Markdown("Upload PDFs (e.g., care logs, medication records) or a JSON file with training pairs.")
    file_input = gr.File(label="Upload Files (PDF/JSON)", file_count="multiple")
    train_button = gr.Button("Start Fine-Tuning")
    output = gr.Textbox(label="Training Status", lines=5)
    train_button.click(fn=train_ui, inputs=file_input, outputs=output)
# Launch the Gradio app
demo.launch()