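"""Gradio app that turns uploaded PDFs into a persona-based SFT dataset.

Text is extracted from each PDF with `unstructured`, a persona is inferred for
each document, and instruction/response pairs are generated with a distilabel
pipeline backed by Hugging Face Inference Endpoints.
"""
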
import multiprocessing
import os
import time

import gradio as gr
import nltk
from distilabel.llms import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import KeepColumns, LoadDataFromDicts
from distilabel.steps.tasks import TextGeneration

from personas import TextToPersona  # assumed to define the TextToPersona task used below

from unstructured.partition.pdf import partition_pdf

# `unstructured`'s text extraction relies on NLTK's sentence tokenizer.
nltk.download("punkt", quiet=True)

PROMPT_TEMPLATE = """\
Generate a single prompt that the persona below might ask an AI assistant:

{{ persona }}
"""

# Read the Hugging Face token from the environment (required for Inference Endpoints).
HF_TOKEN = os.environ.get("HF_TOKEN")


def process_pdfs(pdf_files):
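    """Extract the plain text of each uploaded PDF, one record per file."""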
    all_data = []
    for pdf_file in pdf_files:
        # Partition the PDF into elements (titles, paragraphs, tables, ...)
        # and join their text into a single document string.
        elements = partition_pdf(pdf_file.name)
        full_text = "\n".join(element.text for element in elements)
        all_data.append({"text": full_text.strip()})

    return all_data


def _run_pipeline(result_queue, pdf_files):
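    """Build and run the distilabel pipeline in a worker process, returning
    the resulting distiset to the parent through `result_queue`."""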
    data = process_pdfs(pdf_files)

    with Pipeline(name="personahub-fineweb-edu-text-to-persona") as pipeline:
        input_batch_size = 10

        # Feed the extracted PDF texts into the pipeline.
        data_loader = LoadDataFromDicts(data=data)

        # A single Inference Endpoints client shared by every generation step.
        llm = InferenceEndpointsLLM(
            model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
            api_key=HF_TOKEN,
        )

        # Infer a persona that might engage with each document (PersonaHub-style).
        text_to_persona = TextToPersona(
            llm=llm,
            input_batch_size=input_batch_size,
        )

        # Draft the instruction that the inferred persona would ask.
        text_gen = TextGeneration(
            llm=llm,
            system_prompt="You are an AI assistant expert at simulating user interactions.",
            template=PROMPT_TEMPLATE,
            columns="persona",
            output_mappings={"generation": "instruction"},
            num_generations=1,
        )

        # Answer the generated instruction to complete the SFT pair.
        response_gen = TextGeneration(
            llm=llm,
            system_prompt="You are an AI assistant expert in responding to tasks",
            output_mappings={"generation": "response"},
        )

        # Keep only the columns that make up the final dataset.
        keep = KeepColumns(
            columns=["text", "persona", "model_name", "instruction", "response"],
            input_batch_size=input_batch_size,
        )

        # Wire the steps: load text -> persona -> instruction -> response -> prune columns.
        data_loader >> text_to_persona >> text_gen >> response_gen >> keep

    distiset = pipeline.run(use_cache=False)
    result_queue.put(distiset)


def generate_dataset(pdf_files, progress=gr.Progress()):
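    """Run the pipeline in a subprocess while reporting coarse progress to the UI."""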
    result_queue = multiprocessing.Queue()
    p = multiprocessing.Process(
        target=_run_pipeline,
        args=(result_queue, pdf_files),
    )

    try:
        p.start()
        total_steps = 100
        for step in range(total_steps):
            # Stop polling as soon as the worker process has exited.
            if not p.is_alive():
                break
            progress(
                (step + 1) / total_steps,
                desc="Generating dataset. Don't close this window.",
            )
            time.sleep(2)  # Polling interval; tune to the expected pipeline runtime
        # Read the result before joining: a child process may not exit until
        # the data it put on the queue has been consumed by the parent.
        distiset = result_queue.get()
        p.join()
    except Exception as e:
        raise gr.Error(f"An error occurred during dataset generation: {e}") from e

    df = distiset["default"]["train"].to_pandas()
    progress(1.0, desc="Dataset generation completed")
    return df


def gradio_interface(pdf_files):
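    """Entry point for the Generate button; validates the token before running."""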
    if HF_TOKEN is None:
        raise gr.Error(
            "HF_TOKEN environment variable is not set. Please set it and restart the application."
        )
    df = generate_dataset(pdf_files)
    return df


with gr.Blocks(title="MyPersonas Dataset Generator") as app:
    gr.Markdown("# MyPersonas Dataset Generator")
    gr.Markdown("Upload one or more PDFs to generate a persona based SFT dataset.")

    with gr.Row():
        pdf_files = gr.File(label="Upload PDFs", file_count="multiple")

    with gr.Row():
        generate_button = gr.Button("Generate Dataset")

    output_dataframe = gr.DataFrame(
        label="Generated Dataset",
        interactive=False,
        wrap=True,
    )

    generate_button.click(
        fn=gradio_interface,
        inputs=[pdf_files],
        outputs=[output_dataframe],
    )

if __name__ == "__main__":
    app.launch()