import gradio as gr
import base64
import json
import os
import shutil
import uuid
import glob
from huggingface_hub import CommitScheduler, HfApi, snapshot_download
from pathlib import Path

# Authenticated Hub client built at import time. Indexing os.environ (rather than
# .get) means importing this module raises KeyError when HF_TOKEN is not set.
# NOTE(review): `api` is not referenced anywhere else in this file's visible code.
api = HfApi(token=os.environ["HF_TOKEN"])


# Download existing data from hub
def sync_with_hub(
    repo_id="taesiri/zb_dataset_storage3",
    data_dir="./data",
    download_dir="hub_data",
):
    """
    Synchronize local data with the hub by downloading the latest dataset.

    Steps:
      1. If ``data_dir`` exists, copy it to ``<data_dir>_backup`` next to it
         (e.g. ``./data_backup``), replacing any previous backup.
      2. Download a snapshot of the dataset repo into ``download_dir``.
      3. Copy every sub-folder of the snapshot's ``data/`` directory into
         ``data_dir`` unless a folder with the same name already exists
         locally. Local folders always win; nothing is overwritten.
      4. Delete ``download_dir``. This happens even if step 2 or 3 raised, so
         a failed sync does not leave a stale snapshot behind.

    Args:
        repo_id: Hub dataset repository to pull from.
        data_dir: Local folder holding one sub-folder per question.
        download_dir: Scratch folder for the snapshot; removed afterwards.

    Raises:
        Whatever ``snapshot_download`` or ``shutil`` raise; cleanup still runs.
    """
    print("Starting sync with hub...")
    data_dir = Path(data_dir)
    download_dir = Path(download_dir)

    if data_dir.exists():
        # Keep a copy of the current local state before merging anything in.
        backup_dir = data_dir.parent / f"{data_dir.name}_backup"
        if backup_dir.exists():
            shutil.rmtree(backup_dir)
        shutil.copytree(data_dir, backup_dir)

    try:
        repo_path = snapshot_download(
            repo_id=repo_id, repo_type="dataset", local_dir=str(download_dir)
        )

        hub_data_dir = Path(repo_path) / "data"
        if hub_data_dir.is_dir():
            data_dir.mkdir(parents=True, exist_ok=True)
            for item in hub_data_dir.iterdir():
                dest = data_dir / item.name
                # Only add question folders we don't already have locally.
                if item.is_dir() and not dest.exists():
                    shutil.copytree(item, dest)
    finally:
        # Always remove the scratch snapshot, even if the download/copy failed.
        if download_dir.exists():
            shutil.rmtree(download_dir)
    print("Finished syncing with hub!")


# Background uploader for ./data, created at import time (i.e. before
# sync_with_hub() is called in __main__). Its contents are committed to the
# "data" folder of the dataset repo below.
# NOTE(review): `every=1` is presumably minutes per huggingface_hub's
# CommitScheduler docs — confirm. Those docs also recommend writing into
# folder_path while holding `scheduler.lock` so partially written files are
# not uploaded.
scheduler = CommitScheduler(
    repo_id="taesiri/zb_dataset_storage3",
    repo_type="dataset",
    folder_path="./data",
    path_in_repo="data",
    every=1,
)


def load_existing_questions(data_dir="./data"):
    """
    Load all existing questions from the data directory.

    Every immediate sub-folder of ``data_dir`` that contains a ``question.json``
    is treated as one question whose ID is the folder name. Sub-folders without
    that file are ignored. Files that cannot be read, are not valid JSON, or
    lack a usable ``"question"`` field are skipped with a printed notice.

    Args:
        data_dir: Folder holding one sub-folder per question (default ``./data``).

    Returns:
        A list of ``(question_id, "question_id: preview")`` tuples sorted by
        the second element. The preview is the question text, truncated to its
        first 100 characters plus ``...`` when longer than 100.
    """
    questions = []
    if not os.path.exists(data_dir):
        return questions

    for question_dir in glob.glob(os.path.join(data_dir, "*")):
        if not os.path.isdir(question_dir):
            continue
        json_path = os.path.join(question_dir, "question.json")
        if not os.path.exists(json_path):
            continue
        try:
            with open(json_path, "r", encoding="utf-8") as f:
                data = json.loads(f.read().strip())
            text = data["question"]
            preview = f"{text[:100]}..." if len(text) > 100 else text
        except (OSError, ValueError, KeyError, TypeError) as e:
            # OSError: unreadable file; ValueError: bad JSON / bad encoding;
            # KeyError/TypeError: JSON is not the expected object shape.
            print(f"Skipping unreadable question file {json_path}: {e}")
            continue
        question_id = os.path.basename(question_dir)
        questions.append((question_id, f"{question_id}: {preview}"))

    return sorted(questions, key=lambda x: x[1])


def load_question_data(question_id):
    """
    Read a saved question back into the shape expected by the form.

    Args:
        question_id: The dropdown value. If it contains a colon, only the text
            before the first colon (stripped) is used as the folder name under
            ``./data``; otherwise the value is used unchanged.

    Returns:
        A 27-item list, in this order:
          - 20 form values: name, email, institution, OpenReview profile,
            authorship flag, categories, five sub-question text/answer pairs,
            question, final answer, rationale text, image attribution.
          - Six image paths: four question images, then two rationale images.
            Each is ``None`` when absent or missing on disk.
          - The resolved question ID.
        If the selection is empty, ``question.json`` is missing, or anything
        goes wrong while reading it, all 27 items are ``None``.
    """
    empty_result = [None] * 27
    if not question_id:
        return empty_result

    # Dropdown values look like "<id>: <preview>"; keep only the ID.
    if ":" in question_id:
        question_id = question_id.split(":")[0].strip()

    question_folder = os.path.join("./data", question_id)
    json_path = os.path.join(question_folder, "question.json")
    if not os.path.exists(json_path):
        print(f"Question file not found: {json_path}")
        return empty_result

    try:
        with open(json_path, "r", encoding="utf-8") as handle:
            record = json.loads(handle.read().strip())

        def image_at(paths, position):
            # Stored entries are relative ("./question_image_1.png"); look the
            # file name up inside this question's own folder.
            stored = paths[position] if len(paths) > position else None
            if not stored:
                return None
            on_disk = os.path.join(question_folder, os.path.basename(stored))
            return on_disk if os.path.exists(on_disk) else None

        question_pics = record.get("question_images", [])
        rationale_pics = record.get("rationale_images", [])

        author = record["author_info"]
        # Older records may hold the flag as the string "true"/"false".
        wants_authorship = author.get("authorship_interest", False)
        if isinstance(wants_authorship, str):
            wants_authorship = wants_authorship.lower() == "true"

        fields = [
            author["name"],
            author["email_address"],
            author["institution"],
            author.get("openreview_profile", ""),
            wants_authorship,
        ]
        categories = record["question_categories"]
        fields.append(
            ",".join(categories) if isinstance(categories, list) else categories
        )
        fields.extend(
            record.get(f"subquestions_{n}_{part}", "N/A")
            for n in range(1, 6)
            for part in ("text", "answer")
        )
        fields.extend(
            [
                record["question"],
                record["final_answer"],
                record.get("rationale_text", ""),
                record["image_attribution"],
            ]
        )
        fields.extend(image_at(question_pics, k) for k in range(4))
        fields.extend(image_at(rationale_pics, k) for k in range(2))
        fields.append(question_id)
        return fields
    except Exception as e:
        print(f"Error loading question {question_id}: {str(e)}")
        return empty_result


def _existing_image_basenames(json_path):
    """
    Return the file names of images referenced by an existing question.json.

    Only the base name of each stored path is kept (stored paths look like
    "./question_image_1.png"). The result is an empty set when the file is
    missing, unreadable, not valid JSON, or not a JSON object.
    """
    if not os.path.exists(json_path):
        return set()
    try:
        with open(json_path, "r", encoding="utf-8") as f:
            existing = json.loads(f.read().strip())
    except (OSError, ValueError) as e:
        print(f"Could not read existing question data {json_path}: {e}")
        return set()
    if not isinstance(existing, dict):
        return set()
    referenced = []
    for key in ("question_images", "rationale_images"):
        paths = existing.get(key) or []
        if isinstance(paths, list):
            referenced.extend(p for p in paths if isinstance(p, str) and p)
    return {os.path.basename(p) for p in referenced}


def _store_images(request_folder, labeled_images):
    """
    Save each provided image into ``request_folder`` as ``<label>.png``.

    Args:
        request_folder: Destination folder; must already exist.
        labeled_images: ``(label, image)`` pairs. ``image`` is ``None`` (skipped),
            a file path (copied; paths that don't exist are skipped), or any
            other object, which is passed to ``gr.processing_utils.save_image``.

    Returns:
        ``(label, saved_path)`` for every image actually stored, in input order.
    """
    stored = []
    for label, image in labeled_images:
        if image is None:
            continue
        dest = os.path.join(request_folder, f"{label}.png")
        if isinstance(image, str):
            if not os.path.exists(image):
                continue
            # copy2 overwrites dest in place. Skip the copy when the source *is*
            # the destination (an image kept from a previous save).
            if os.path.abspath(image) != os.path.abspath(dest):
                shutil.copy2(image, dest)
        else:
            # Non-path values (e.g. numpy arrays) go through Gradio's helper.
            gr.processing_utils.save_image(image, dest)
        stored.append((label, dest))
    return stored


def generate_json_files(
    name,
    email_address,
    institution,
    openreview_profile,
    authorship_interest,
    question_categories,
    subquestion_1_text,
    subquestion_1_answer,
    subquestion_2_text,
    subquestion_2_answer,
    subquestion_3_text,
    subquestion_3_answer,
    subquestion_4_text,
    subquestion_4_answer,
    subquestion_5_text,
    subquestion_5_answer,
    question,
    final_answer,
    rationale_text,
    image_attribution,
    image1,
    image2,
    image3,
    image4,
    rationale_image1,
    rationale_image2,
    existing_id=None,  # ID of the question being updated; None creates a new one
    data_dir="./data",
):
    """
    Persist one question as ``<data_dir>/<request_id>/question.json`` plus images.

    For each request:
      1) Use ``existing_id`` as the folder name when updating; otherwise use a
         fresh UUID4.
      2) Save the uploaded images into that folder as ``<label>.png``. The
         labels are ``question_image_1..4`` and ``rationale_image_1..2``.
         When updating, an empty image slot keeps the image already saved for
         that slot. Slots are matched by file name, not by list position.
      3) Write ``question.json`` as a single JSON line.

    Text arguments that are ``None`` are stored as ``""``. ``question_categories``
    is a comma-separated string that becomes a list of trimmed, non-empty tags.

    Returns:
        Path of the written ``question.json``.
    """
    request_id = existing_id if existing_id else str(uuid.uuid4())

    request_folder = os.path.join(data_dir, request_id)
    os.makedirs(request_folder, exist_ok=True)

    def safe_str(val):
        return val if val is not None else ""

    name = safe_str(name)
    email_address = safe_str(email_address)
    institution = safe_str(institution)
    openreview_profile = safe_str(openreview_profile)
    # A checkbox value (bool) passes through unchanged; only None becomes "".
    authorship_interest = safe_str(authorship_interest)
    image_attribution = safe_str(image_attribution)
    # Drop empty tags produced by stray/trailing commas ("math, geometry,").
    question_categories = (
        [cat.strip() for cat in question_categories.split(",") if cat.strip()]
        if question_categories
        else []
    )
    subquestion_1_text = safe_str(subquestion_1_text)
    subquestion_1_answer = safe_str(subquestion_1_answer)
    subquestion_2_text = safe_str(subquestion_2_text)
    subquestion_2_answer = safe_str(subquestion_2_answer)
    subquestion_3_text = safe_str(subquestion_3_text)
    subquestion_3_answer = safe_str(subquestion_3_answer)
    subquestion_4_text = safe_str(subquestion_4_text)
    subquestion_4_answer = safe_str(subquestion_4_answer)
    subquestion_5_text = safe_str(subquestion_5_text)
    subquestion_5_answer = safe_str(subquestion_5_answer)
    question = safe_str(question)
    final_answer = safe_str(final_answer)
    rationale_text = safe_str(rationale_text)

    all_images = [
        ("question_image_1", image1),
        ("question_image_2", image2),
        ("question_image_3", image3),
        ("question_image_4", image4),
        ("rationale_image_1", rationale_image1),
        ("rationale_image_2", rationale_image2),
    ]

    if existing_id:
        # Stored references are relative to the question folder, so resolve them
        # there (not against the process CWD). Match them by file name, because
        # the stored lists skip empty slots.
        kept = _existing_image_basenames(
            os.path.join(request_folder, "question.json")
        )
        for i, (label, image) in enumerate(all_images):
            no_new_image = image is None or (isinstance(image, str) and not image)
            previous = os.path.join(request_folder, f"{label}.png")
            if no_new_image and f"{label}.png" in kept and os.path.exists(previous):
                all_images[i] = (label, previous)

    files_list = _store_images(request_folder, all_images)

    # Image references in the JSON are relative to the question folder.
    question_images = [
        os.path.join(".", os.path.basename(path))
        for label, path in files_list
        if "question_image" in label
    ]
    rationale_images = [
        os.path.join(".", os.path.basename(path))
        for label, path in files_list
        if "rationale_image" in label
    ]

    item_urls = {
        "custom_id": f"question___{request_id}",
        "author_info": {
            "name": name,
            "email_address": email_address,
            "institution": institution,
            "openreview_profile": openreview_profile,
            "authorship_interest": authorship_interest,
        },
        "question_categories": question_categories,
        "image_attribution": image_attribution,
        "question": question,
        "question_images": question_images,
        "final_answer": final_answer,
        "rationale_text": rationale_text,
        "rationale_images": rationale_images,
        "subquestions_1_text": subquestion_1_text,
        "subquestions_1_answer": subquestion_1_answer,
        "subquestions_2_text": subquestion_2_text,
        "subquestions_2_answer": subquestion_2_answer,
        "subquestions_3_text": subquestion_3_text,
        "subquestions_3_answer": subquestion_3_answer,
        "subquestions_4_text": subquestion_4_text,
        "subquestions_4_answer": subquestion_4_answer,
        "subquestions_5_text": subquestion_5_text,
        "subquestions_5_answer": subquestion_5_answer,
    }

    urls_jsonl_path = os.path.join(request_folder, "question.json")
    with open(urls_jsonl_path, "w", encoding="utf-8") as f:
        f.write(json.dumps(item_urls, ensure_ascii=False) + "\n")

    return urls_jsonl_path


# Build the Gradio app
# Build the Gradio app. Everything inside this `with` runs at import time, so
# any component defaults computed here reflect ./data as it is on import.
with gr.Blocks() as demo:
    gr.Markdown("# Dataset Builder")
    # Session state holding the ID of a question loaded via "Load Selected
    # Question"; validate_and_generate passes it to generate_json_files as
    # existing_id so the question is updated instead of duplicated.
    loaded_question_id = gr.State()

    # Static contributor instructions (HTML rendered as-is).
    with gr.Accordion("Instructions", open=True):
        gr.HTML(
            """
            <h3>Instructions:</h3>
            <p>Welcome to the Hugging Face space for collecting questions for new benchmark datasets.</p>
            
            <table style="width:100%; border-collapse: collapse; margin: 10px 0;">
                <tr>
                    <th style="width:50%; background-color: #3366f0; padding: 8px; text-align: left; border: 1px solid #ddd;">
                        Required Fields
                    </th>
                    <th style="width:50%; background-color: #3366f0; padding: 8px; text-align: left; border: 1px solid #ddd;">
                        Optional Fields
                    </th>
                </tr>
                <tr>
                    <td style="vertical-align: top; padding: 8px; border: 1px solid #ddd;">
                        <ul style="margin: 0;">
                            <li>Author Information</li>
                            <li>At least <b>one question image</b></li>
                            <li>The <b>question text</b></li>
                            <li>The <b>final answer</b></li>
                            <li><b>Sub-questions</b> with their answers (write 'N/A' if breaking into steps is not reasonable - please use sparingly)</li>
                        </ul>
                    </td>
                    <td style="vertical-align: top; padding: 8px; border: 1px solid #ddd;">
                        <ul style="margin: 0;">
                            <li>Up to three additional question images</li>
                            <li>Supporting images for your answer</li>
                            <li><b>Rationale text</b> to explain your reasoning</li>
                        </ul>
                    </td>
                </tr>
            </table>

            <h3>Question Criteria:</h3>
            <ul>
                <li>Make questions as challenging as possible. At a minimum, obtaining the correct answer needs to be beyond the capabilities of state-of-the-art large multimodal models.</li>
                <li>Structure your questions to require multiple steps/sub-questions to reach the final answer (e.g., identifying/counting specific objects in the image or requiring a particular piece of knowledge) — this will likely enable better differentiation of model performance.</li>
                <li>Include images/questions that are not copyright-restricted.</li>
            </ul>

            <h3>Authorship Opportunity:</h3>
            <p>Would you like to be included as an author on our paper? Authorship is offered to anyone submitting 5 or more difficult questions!</p>

            <p>While not all fields are mandatory, providing additional context through optional fields will help create a more comprehensive dataset. After submitting a question, you can clear up the form to submit another one.</p>
            """
        )
    # ---- Author information (these five values survive "Clear Form") ----
    gr.Markdown("## Author Information")
    with gr.Row():
        name_input = gr.Textbox(label="Name", lines=1)
        email_address_input = gr.Textbox(label="Email Address", lines=1)
        institution_input = gr.Textbox(
            label="Institution or 'Independent'",
            lines=1,
            placeholder="e.g. MIT, Google, Independent, etc.",
        )
        openreview_profile_input = gr.Textbox(
            label="OpenReview Profile Name",
            lines=1,
            placeholder="Your OpenReview username or profile name",
        )

    # Authorship opt-in; saved under author_info["authorship_interest"].
    authorship_input = gr.Checkbox(
        label="Would you like to be considered for authorship? (Requires submitting 5+ difficult questions)",
        value=False,
    )

    gr.Markdown("## Question Information")

    # Image attribution (required by validate_and_generate).
    gr.Markdown("### Images Attribution")
    image_attribution_input = gr.Textbox(
        label="Images Attribution",
        lines=1,
        placeholder="Include attribution information for the images used in this question (or 'Own' if you created/took them)",
    )

    # Question images, one tab each. type="filepath" means handlers receive file
    # paths; only Image 1 is required.
    with gr.Tabs():
        with gr.Tab("Image 1"):
            image1 = gr.Image(label="Question Image 1", type="filepath")
        with gr.Tab("Image 2 (Optional)"):
            image2 = gr.Image(label="Question Image 2", type="filepath")
        with gr.Tab("Image 3 (Optional)"):
            image3 = gr.Image(label="Question Image 3", type="filepath")
        with gr.Tab("Image 4 (Optional)"):
            image4 = gr.Image(label="Question Image 4", type="filepath")

    question_input = gr.Textbox(
        label="Question", lines=15, placeholder="Type your question here..."
    )

    # Free-form comma-separated tags; split into a list when saved.
    question_categories_input = gr.Textbox(
        label="Question Categories",
        lines=1,
        placeholder="Comma-separated tags, e.g. math, geometry",
    )

    # Answer Section
    gr.Markdown("## Answer ")

    final_answer_input = gr.Textbox(
        label="Final Answer",
        lines=1,
        placeholder="Enter the short/concise final answer...",
    )

    rationale_text_input = gr.Textbox(
        label="Rationale Text",
        lines=5,
        placeholder="Enter the reasoning or explanation for the answer...",
    )

    # Optional rationale images, one tab each.
    with gr.Tabs():
        with gr.Tab("Rationale 1 (Optional)"):
            rationale_image1 = gr.Image(label="Rationale Image 1", type="filepath")
        with gr.Tab("Rationale 2 (Optional)"):
            rationale_image2 = gr.Image(label="Rationale Image 2", type="filepath")

    # Sub-questions: five text/answer pairs, pre-filled with "N/A" so that the
    # "required" check in validate_and_generate passes when steps don't apply.
    gr.Markdown("## Subquestions")
    with gr.Row():
        subquestion_1_text_input = gr.Textbox(
            label="Subquestion 1 Text",
            lines=2,
            placeholder="First sub-question...",
            value="N/A",
        )
        subquestion_1_answer_input = gr.Textbox(
            label="Subquestion 1 Answer",
            lines=2,
            placeholder="Answer to sub-question 1...",
            value="N/A",
        )

    with gr.Row():
        subquestion_2_text_input = gr.Textbox(
            label="Subquestion 2 Text",
            lines=2,
            placeholder="Second sub-question...",
            value="N/A",
        )
        subquestion_2_answer_input = gr.Textbox(
            label="Subquestion 2 Answer",
            lines=2,
            placeholder="Answer to sub-question 2...",
            value="N/A",
        )

    with gr.Row():
        subquestion_3_text_input = gr.Textbox(
            label="Subquestion 3 Text",
            lines=2,
            placeholder="Third sub-question...",
            value="N/A",
        )
        subquestion_3_answer_input = gr.Textbox(
            label="Subquestion 3 Answer",
            lines=2,
            placeholder="Answer to sub-question 3...",
            value="N/A",
        )

    with gr.Row():
        subquestion_4_text_input = gr.Textbox(
            label="Subquestion 4 Text",
            lines=2,
            placeholder="Fourth sub-question...",
            value="N/A",
        )
        subquestion_4_answer_input = gr.Textbox(
            label="Subquestion 4 Answer",
            lines=2,
            placeholder="Answer to sub-question 4...",
            value="N/A",
        )

    with gr.Row():
        subquestion_5_text_input = gr.Textbox(
            label="Subquestion 5 Text",
            lines=2,
            placeholder="Fifth sub-question...",
            value="N/A",
        )
        subquestion_5_answer_input = gr.Textbox(
            label="Subquestion 5 Answer",
            lines=2,
            placeholder="Answer to sub-question 5...",
            value="N/A",
        )

    with gr.Row():
        submit_button = gr.Button("Submit")
        clear_button = gr.Button("Clear Form")

    # Hidden download slots. In the visible code no handler ever fills them:
    # output_file_urls is only reset by the Clear Form handler, and
    # output_file_base64 is not referenced anywhere else.
    with gr.Row():
        output_file_urls = gr.File(
            label="Download URLs JSON", interactive=False, visible=False
        )
        output_file_base64 = gr.File(
            label="Download Base64 JSON", interactive=False, visible=False
        )

    # Re-open previously saved questions for editing.
    with gr.Accordion("Load Existing Question", open=False):
        gr.Markdown("## Load Existing Question")

        with gr.Row():
            # Initial choices are computed once, when this UI is built.
            # NOTE(review): choices are (question_id, "id: preview") tuples;
            # presumably Gradio treats them as (display name, value), which is
            # why load_question_data strips everything after ":" — confirm.
            existing_questions = gr.Dropdown(
                label="Load Existing Question",
                choices=load_existing_questions(),
                type="value",
                allow_custom_value=False,
            )
            refresh_button = gr.Button("🔄 Refresh")
            load_button = gr.Button("Load Selected Question")

    def refresh_questions():
        """Return the dropdown rebuilt from the questions currently on disk."""
        return gr.Dropdown(choices=load_existing_questions())

    refresh_button.click(fn=refresh_questions, inputs=[], outputs=[existing_questions])

    # The dropdown's initial choices are computed while this UI is built at
    # import time, i.e. before sync_with_hub() runs under __main__. Without
    # this, questions pulled from the Hub only appear after clicking "Refresh".
    # Re-populating on every page load also shows other users' new questions.
    demo.load(fn=refresh_questions, inputs=[], outputs=[existing_questions])

    # Load button: fill every form field (and the loaded-ID state) from disk.
    load_button.click(
        fn=load_question_data,
        inputs=[existing_questions],
        outputs=[
            name_input,
            email_address_input,
            institution_input,
            openreview_profile_input,
            authorship_input,
            question_categories_input,
            subquestion_1_text_input,
            subquestion_1_answer_input,
            subquestion_2_text_input,
            subquestion_2_answer_input,
            subquestion_3_text_input,
            subquestion_3_answer_input,
            subquestion_4_text_input,
            subquestion_4_answer_input,
            subquestion_5_text_input,
            subquestion_5_answer_input,
            question_input,
            final_answer_input,
            rationale_text_input,
            image_attribution_input,
            image1,
            image2,
            image3,
            image4,
            rationale_image1,
            rationale_image2,
            loaded_question_id,
        ],
    )

    # Submit handler: validate the form, then create or update the question.
    def validate_and_generate(
        nm,
        em,
        inst,
        orp,
        auth,
        qcats,
        sq1t,
        sq1a,
        sq2t,
        sq2a,
        sq3t,
        sq3a,
        sq4t,
        sq4a,
        sq5t,
        sq5a,
        q,
        fa,
        rt,
        ia,
        i1,
        i2,
        i3,
        i4,
        ri1,
        ri2,
        stored_question_id,  # set by "Load Selected Question"; None for a new question
    ):
        """
        Validate the submitted form and persist it with ``generate_json_files``.

        These fields are required:
          - name, email and institution;
          - question and final answer;
          - the first question image and the image attribution;
          - all five sub-question/answer pairs (the form pre-fills them with "N/A").
        When ``stored_question_id`` is set, that question is updated in place;
        otherwise a new one is created.

        Returns:
            ``(submit button update, refreshed existing-questions dropdown)``.
            The submit button is disabled after a successful save, to prevent
            duplicate submissions. It stays enabled when validation or saving
            fails.
        """

        def is_blank(text):
            # None, "" and whitespace-only strings all count as missing.
            return not text or not text.strip()

        missing_fields = [
            label
            for label, value in (
                ("Name", nm),
                ("Email Address", em),
                ("Institution", inst),
                ("Question", q),
                ("Final Answer", fa),
            )
            if is_blank(value)
        ]
        if not i1:
            missing_fields.append("First Question Image")
        if is_blank(ia):
            missing_fields.append("Image Attribution")
        for ordinal, text, answer in (
            ("First", sq1t, sq1a),
            ("Second", sq2t, sq2a),
            ("Third", sq3t, sq3a),
            ("Fourth", sq4t, sq4a),
            ("Fifth", sq5t, sq5a),
        ):
            if is_blank(text) or is_blank(answer):
                missing_fields.append(f"{ordinal} Sub-question and Answer")

        if missing_fields:
            warning_msg = f"Required fields missing: {', '.join(missing_fields)} ⛔️"
            gr.Warning(warning_msg, duration=5)
            return gr.Button(interactive=True), gr.Dropdown(
                choices=load_existing_questions()
            )

        # Use the stored ID (from the loaded question) instead of the dropdown.
        existing_id = stored_question_id if stored_question_id else None

        try:
            # ./data is the CommitScheduler's folder. Holding its lock while
            # writing keeps the background uploader from committing a
            # half-written question folder.
            with scheduler.lock:
                generate_json_files(
                    nm,
                    em,
                    inst,
                    orp,
                    auth,
                    qcats,
                    sq1t,
                    sq1a,
                    sq2t,
                    sq2a,
                    sq3t,
                    sq3a,
                    sq4t,
                    sq4a,
                    sq5t,
                    sq5a,
                    q,
                    fa,
                    rt,
                    ia,
                    i1,
                    i2,
                    i3,
                    i4,
                    ri1,
                    ri2,
                    existing_id,
                )
        except OSError as e:
            # Disk/permission problems: tell the user and let them retry.
            print(f"Failed to save question: {e}")
            gr.Warning(f"Could not save the question: {e} ⛔️", duration=5)
            return gr.Button(interactive=True), gr.Dropdown(
                choices=load_existing_questions()
            )

        action = "updated" if existing_id else "created"
        gr.Info(
            f"Dataset item {action} successfully! 🎉 Clear the form to submit a new one"
        )

        return gr.update(interactive=False), gr.Dropdown(
            choices=load_existing_questions()
        )

    # Submit wiring: inputs are in the same order as validate_and_generate's parameters.
    submit_button.click(
        fn=validate_and_generate,
        inputs=[
            name_input,
            email_address_input,
            institution_input,
            openreview_profile_input,
            authorship_input,
            question_categories_input,
            subquestion_1_text_input,
            subquestion_1_answer_input,
            subquestion_2_text_input,
            subquestion_2_answer_input,
            subquestion_3_text_input,
            subquestion_3_answer_input,
            subquestion_4_text_input,
            subquestion_4_answer_input,
            subquestion_5_text_input,
            subquestion_5_answer_input,
            question_input,
            final_answer_input,
            rationale_text_input,
            image_attribution_input,
            image1,
            image2,
            image3,
            image4,
            rationale_image1,
            rationale_image2,
            loaded_question_id,
        ],
        outputs=[submit_button, existing_questions],
    )

    # Clear handler: wipe the question but keep the contributor's details.
    def clear_form_fields(name, email, inst, openreview, authorship, *args):
        """
        Produce the 30 output values that reset the form for a new submission.

        The five author fields are echoed back unchanged. Categories, question,
        final answer, rationale and attribution are blanked, and all ten
        sub-question boxes go back to "N/A". All six images and the hidden file
        output are cleared. The submit button is re-enabled, the dropdown is
        refreshed, and the loaded-question ID is reset to None.
        """
        kept_author = [name, email, inst, openreview, authorship]
        # Fresh update dicts for each component (not one shared object).
        blank_categories = [gr.update(value="")]
        reset_subquestions = [gr.update(value="N/A") for _ in range(10)]
        blank_main_text = [gr.update(value="") for _ in range(4)]
        no_images = [None] * 6
        trailing = [
            None,
            gr.Button(interactive=True),
            gr.update(choices=load_existing_questions()),
            None,
        ]
        reset_values = (
            kept_author
            + blank_categories
            + reset_subquestions
            + blank_main_text
            + no_images
            + trailing
        )
        gr.Info("Form cleared! Ready for new submission 🔄")
        return reset_values

    # Clear wiring: outputs line up one-to-one with the list built above.
    clear_button.click(
        clear_form_fields,
        inputs=[
            name_input,
            email_address_input,
            institution_input,
            openreview_profile_input,
            authorship_input,
        ],
        outputs=[
            name_input,
            email_address_input,
            institution_input,
            openreview_profile_input,
            authorship_input,
            question_categories_input,
            subquestion_1_text_input,
            subquestion_1_answer_input,
            subquestion_2_text_input,
            subquestion_2_answer_input,
            subquestion_3_text_input,
            subquestion_3_answer_input,
            subquestion_4_text_input,
            subquestion_4_answer_input,
            subquestion_5_text_input,
            subquestion_5_answer_input,
            question_input,
            final_answer_input,
            rationale_text_input,
            image_attribution_input,
            image1,
            image2,
            image3,
            image4,
            rationale_image1,
            rationale_image2,
            output_file_urls,
            submit_button,
            existing_questions,
            loaded_question_id,
        ],
    )

# Script entry point: pull Hub data into ./data, then serve the UI.
if __name__ == "__main__":
    print("Initializing app...")
    # NOTE: the Blocks UI and the CommitScheduler above were already created at
    # import time, before this sync runs; anything computed during the UI build
    # (such as the dropdown's initial choices) predates it.
    sync_with_hub()  # Sync before launching the app
    print("Starting Gradio interface...")
    demo.launch()