# app.py

import gradio as gr
from transformers import LlamaForCausalLM, LlamaTokenizer
import datasets
import torch
import json
import os
import pdfplumber
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from accelerate import Accelerator  # not used directly, but accelerate must be installed for device_map="auto"
import bitsandbytes  # required for 8-bit quantization; imported here to fail fast if missing
import sentencepiece  # required by LlamaTokenizer; imported here to fail fast if missing
import huggingface_hub
from transformers import TrainingArguments, Trainer, BitsAndBytesConfig

# Debug: list environment variable names to confirm 'LLama' is present
# (names only; dumping os.environ values would leak every secret to the logs)
print("Environment variable names:", sorted(os.environ))

# Retrieve the token from the Hugging Face Space secret named 'LLama'
LLama = os.getenv("LLama")
if not LLama:
    raise ValueError("LLama token not found in environment variables. Please set it in Hugging Face Space secrets under 'Settings' > 'Secrets' as 'LLama'.")

# Debug: Print the token to verify it's being read (remove this in production)
print(f"Retrieved LLama token: {LLama[:5]}... (first 5 chars for security)")

# Authenticate with Hugging Face
huggingface_hub.login(token=LLama)

# Model setup
MODEL_ID = "meta-llama/Llama-2-7b-hf"
tokenizer = LlamaTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# Load model with the default attention mechanism (no Flash Attention) in 8-bit;
# quantization_config is the current idiom for the deprecated load_in_8bit kwarg
model = LlamaForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)
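# Rough footprint check (an estimate, not measured): ~6.7B params at 1 byte each
# in 8-bit is on the order of 7 GB of GPU memory for weights alone, before
# activations, gradients, and optimizer state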

# Add padding token if it doesn't exist and resize embeddings
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    model.resize_token_embeddings(len(tokenizer))

# Prepare model for LoRA training
model = prepare_model_for_kbit_training(model)
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
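# Sanity check on the printout above (assuming Llama-2-7B dims: 32 layers, 4096
# hidden): r=16 adapters on four attention projections add 2 * 4096 * 16 params
# per matrix, roughly 17M trainable out of ~6.7B total (~0.25%)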

# Function to process uploaded files and train
def train_ui(files):
    try:
        if not files:
            return "Error: No files uploaded."
        # Process multiple PDFs or JSON
        raw_text = ""
        dataset = None  # Initialize dataset as None
        for file in files:
            if file.name.endswith(".pdf"):
                with pdfplumber.open(file.name) as pdf:
                    for page in pdf.pages:
                        raw_text += page.extract_text() or ""
            elif file.name.endswith(".json"):
                with open(file.name, "r", encoding="utf-8") as f:
                    raw_data = json.load(f)
                training_data = raw_data.get("training_pairs", raw_data)
                # Write a normalized copy and point datasets at the "training_pairs"
                # field so each pair becomes its own row
                with open("temp_fraud_data.json", "w", encoding="utf-8") as out_file:
                    json.dump({"training_pairs": training_data}, out_file)
                dataset = datasets.load_dataset("json", data_files="temp_fraud_data.json", field="training_pairs")

        if not raw_text and not dataset:
            return "Error: No valid PDF or JSON data found."

        # Create training pairs from PDFs only if no JSON dataset was supplied
        if raw_text and dataset is None:
            def create_training_pairs(text):
                pairs = []
                if "Haloperidol" in text and "daily" in text.lower():
                    pairs.append({
                        "input": "Patient received Haloperidol daily. Is this overmedication?",
                        "output": "Yes, daily Haloperidol use without documented severe psychosis or failed alternatives may indicate overmedication, violating CMS guidelines."
                    })
                if "Lorazepam" in text and "frequent" in text.lower():
                    pairs.append({
                        "input": "Care logs show frequent Lorazepam use with a 90-day supply. Is this suspicious?",
                        "output": "Yes, frequent use with a large supply suggests potential overuse or mismanagement, a fraud indicator."
                    })
                return pairs
            training_data = create_training_pairs(raw_text)
            with open("temp_fraud_data.json", "w") as f:
                json.dump({"training_pairs": training_data}, f)
            dataset = datasets.load_dataset("json", data_files="temp_fraud_data.json")
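            # Note: this heuristic yields at most two pairs per upload, so the
            # resulting run is a smoke test rather than meaningful fine-tuning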

        # Tokenization: wrap each pair in the Llama-2 instruction format
        def tokenize_data(example):
            formatted_text = f"<s>[INST] {example['input']} [/INST] {example['output']}</s>"
            inputs = tokenizer(formatted_text, padding="max_length", truncation=True, max_length=4096)
            # Mask padding positions so the loss is computed only on real tokens
            inputs["labels"] = [
                tok if tok != tokenizer.pad_token_id else -100 for tok in inputs["input_ids"]
            ]
            return inputs

        # Map one example at a time: the function formats a single pair, so
        # batched=True would interpolate whole column lists into the prompt string
        tokenized_dataset = dataset["train"].map(tokenize_data, remove_columns=dataset["train"].column_names)

        # Training setup
        training_args = TrainingArguments(
            output_dir="./fine_tuned_llama_healthcare",
            per_device_train_batch_size=4,
            gradient_accumulation_steps=8,
            eval_strategy="no",
            save_strategy="epoch",
            save_total_limit=2,
            num_train_epochs=5,
            learning_rate=2e-5,
            weight_decay=0.01,
            logging_dir="./logs",
            logging_steps=10,
            bf16=True,
            gradient_checkpointing=True,
            optim="adamw_torch",
            warmup_steps=100,
        )
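        # Effective batch size is per_device_train_batch_size * gradient_accumulation_steps
        # = 4 * 8 = 32 sequences per optimizer step (per GPU)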

        def custom_data_collator(features):
            # The mapped dataset stores plain Python lists, so build tensors here
            return {
                "input_ids": torch.tensor([f["input_ids"] for f in features]),
                "attention_mask": torch.tensor([f["attention_mask"] for f in features]),
                "labels": torch.tensor([f["labels"] for f in features]),
            }

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_dataset,
            data_collator=custom_data_collator,
        )
        trainer.train()
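        # With a PEFT-wrapped model, save_pretrained writes only the LoRA adapter
        # weights (a few MB), not the full base model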
        model.save_pretrained("./fine_tuned_llama_healthcare")
        tokenizer.save_pretrained("./fine_tuned_llama_healthcare")
        return "Training completed! Model saved to ./fine_tuned_llama_healthcare"

    except Exception as e:
        return f"Error: {str(e)}. Please check file format, dependencies, or the LLama token."

# Gradio UI
with gr.Blocks(title="Healthcare Fraud Detection Fine-Tuning") as demo:
    gr.Markdown("# Fine-Tune LLaMA 2 for Healthcare Fraud Analysis")
    gr.Markdown("Upload PDFs (e.g., care logs, medication records) or a JSON file with training pairs.")
    file_input = gr.File(label="Upload Files (PDF/JSON)", file_count="multiple")
    train_button = gr.Button("Start Fine-Tuning")
    output = gr.Textbox(label="Training Status", lines=5)
    train_button.click(fn=train_ui, inputs=file_input, outputs=output)

# Launch the Gradio app
demo.launch()
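
# To reuse the fine-tuned adapter later, a minimal sketch (assumes the same base
# model is available and `peft` is installed):
#
#   from peft import PeftModel
#   base = LlamaForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16, device_map="auto")
#   tuned = PeftModel.from_pretrained(base, "./fine_tuned_llama_healthcare")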