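"""Code-generation backend for a Gradio app.

Routes chat requests to Hugging Face inference providers, optionally enriches
the prompt with uploaded-file text, scraped website content, and Tavily-backed
web search, and streams the generated code (with a live HTML preview) back to
the UI.
"""
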
import os
from typing import Dict, List, Optional, Tuple

import gradio as gr
from huggingface_hub import InferenceClient
from tavily import TavilyClient

from config import (
    HTML_SYSTEM_PROMPT, GENERIC_SYSTEM_PROMPT, HTML_SYSTEM_PROMPT_WITH_SEARCH,
    GENERIC_SYSTEM_PROMPT_WITH_SEARCH, FollowUpSystemPrompt
)
from chat_processing import (
    history_to_messages, messages_to_history,
    remove_code_block, apply_search_replace_changes, send_to_sandbox,
    history_to_chatbot_messages, get_gradio_language
)
from file_processing import (
    extract_text_from_file, create_multimodal_message,
)
from web_extraction import extract_website_content, enhance_query_with_search

# HF Inference Client
HF_TOKEN = os.getenv('HF_TOKEN')

def get_inference_client(model_id):
    """Return an InferenceClient with provider based on model_id."""
    provider = "groq" if model_id == "moonshotai/Kimi-K2-Instruct" else "auto"
    return InferenceClient(
        provider=provider,
        api_key=HF_TOKEN,
        bill_to="huggingface"
    )

# Tavily Search Client
TAVILY_API_KEY = os.getenv('TAVILY_API_KEY')
tavily_client = None
if TAVILY_API_KEY:
    try:
        tavily_client = TavilyClient(api_key=TAVILY_API_KEY)
    except Exception as e:
        print(f"Failed to initialize Tavily client: {e}")
        tavily_client = None

async def generation_code(
    query: Optional[str],
    image: Optional[gr.Image],
    file: Optional[str],
    website_url: Optional[str],
    _setting: Dict[str, str],
    _history: Optional[List[Tuple[str, str]]],
    _current_model: Dict,
    enable_search: bool = False,
    language: str = "html",
    progress=gr.Progress(track_tqdm=True),
):
    if query is None:
        query = ''
    if _history is None:
        _history = []

    # If the last assistant message already contains an HTML document,
    # treat this request as a modification of the existing page.
    has_existing_html = False
    if _history:
        last_assistant_msg = _history[-1][1]
        if '<!DOCTYPE html>' in last_assistant_msg or '<html' in last_assistant_msg:
            has_existing_html = True
    progress(0, desc="Initializing...")

    # Choose system prompt based on context
    if has_existing_html:
        # Use follow-up prompt for modifying existing HTML
        system_prompt = FollowUpSystemPrompt
    else:
        # Use language-specific prompt
        if language == "html":
            system_prompt = HTML_SYSTEM_PROMPT_WITH_SEARCH if enable_search else HTML_SYSTEM_PROMPT
        else:
            system_prompt = GENERIC_SYSTEM_PROMPT_WITH_SEARCH.format(language=language) if enable_search else GENERIC_SYSTEM_PROMPT.format(language=language)

    messages = history_to_messages(_history, system_prompt)

    # Extract file text and append to query if file is present
    file_text = ""
    progress(0.1, desc="Processing file...")
    if file:
        file_text = extract_text_from_file(file)
        if file_text:
            file_text = file_text[:5000]  # Limit to 5000 chars for prompt size
            query = f"{query}\n\n[Reference file content below]\n{file_text}"

    progress(0.2, desc="Extracting website content...")
    # Extract website content and append to query if website URL is present
    website_text = ""
    if website_url and website_url.strip():
        website_text = extract_website_content(website_url.strip())
        if website_text and not website_text.startswith("Error"):
            website_text = website_text[:8000]  # Limit to 8000 chars for prompt size
            query = f"{query}\n\n[Website content to redesign below]\n{website_text}"
        elif website_text.startswith("Error"):
            # Provide helpful guidance when website extraction fails
            fallback_guidance = """
Since I couldn't extract the website content, please provide additional details about what you'd like to build:
1. What type of website is this? (e.g., e-commerce, blog, portfolio, dashboard)
2. What are the main features you want?
3. What's the target audience?
4. Any specific design preferences? (colors, style, layout)
This will help me create a better design for you."""
            query = f"{query}\n\n[Error extracting website: {website_text}]{fallback_guidance}"

    progress(0.4, desc="Performing web search...")
    # Enhance query with search if enabled
    enhanced_query = enhance_query_with_search(query, enable_search)

    # Use dynamic client based on selected model
    client = get_inference_client(_current_model["id"])

    if image is not None:
        messages.append(create_multimodal_message(enhanced_query, image))
    else:
        messages.append({'role': 'user', 'content': enhanced_query})
    progress(0.5, desc="Generating code with AI model...")
    try:
        completion = client.chat.completions.create(
            model=_current_model["id"],
            messages=messages,
            stream=True,
            max_tokens=5000
        )
        progress(0.6, desc="Streaming response...")
        content = ""
        for chunk in completion:
            if chunk.choices[0].delta.content:
                content += chunk.choices[0].delta.content
                clean_code = remove_code_block(content)
                if has_existing_html:
                    # The follow-up prompt asks the model for search/replace
                    # blocks rather than a full document, so apply them to the
                    # previous HTML. If the model returned a complete HTML file
                    # anyway, use it directly.
                    if not (clean_code.strip().startswith("<!DOCTYPE html>") or clean_code.strip().startswith("<html")):
                        last_html = _history[-1][1] if _history else ""
                        modified_html = apply_search_replace_changes(last_html, clean_code)
                        clean_code = remove_code_block(modified_html)

                yield (
                    gr.update(value=clean_code, language=get_gradio_language(language)),
                    _history,
                    send_to_sandbox(clean_code) if language == "html" else "<div style='padding:1em;color:#888;text-align:center;'>Preview is only available for HTML.</div>",
                    history_to_chatbot_messages(_history)
                )
        # Final update: resolve search/replace responses the same way as during
        # streaming, so the stored result matches what was shown, and keep the
        # preview gated to HTML output.
        final_code = remove_code_block(content)
        if has_existing_html and not (final_code.strip().startswith("<!DOCTYPE html>") or final_code.strip().startswith("<html")):
            last_html = _history[-1][1] if _history else ""
            final_code = remove_code_block(apply_search_replace_changes(last_html, final_code))
        _history = messages_to_history(messages + [{'role': 'assistant', 'content': content}])
        yield (
            final_code,
            _history,
            send_to_sandbox(final_code) if language == "html" else "<div style='padding:1em;color:#888;text-align:center;'>Preview is only available for HTML.</div>",
            history_to_chatbot_messages(_history),
        )

    except Exception as e:
        error_message = f"Error: {str(e)}"
        yield (error_message, _history, None, history_to_chatbot_messages(_history))
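
# --- Illustrative usage sketch (not part of the original Space) ---
# A minimal Blocks UI showing how generation_code could be wired up. The
# component layout, labels, and default model dict below are assumptions for
# demonstration only; the real app defines its own UI and model registry
# elsewhere.
if __name__ == "__main__":
    _demo_model = {"id": "moonshotai/Kimi-K2-Instruct"}  # assumed default model

    with gr.Blocks() as demo:
        query_box = gr.Textbox(label="What should I build?")
        code_out = gr.Code(label="Generated code")
        preview = gr.HTML(label="Live preview")
        chat = gr.Chatbot(label="History")
        history_state = gr.State([])

        gr.Button("Generate").click(
            generation_code,
            inputs=[
                query_box,
                gr.State(None),         # image
                gr.State(None),         # file
                gr.State(""),           # website_url
                gr.State({}),           # _setting
                history_state,          # _history
                gr.State(_demo_model),  # _current_model
            ],
            outputs=[code_out, history_state, preview, chat],
        )

    demo.launch()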