import html
import json
import os
import re
import subprocess # 🥲
import time
from datetime import datetime

import gradio as gr
import numpy as np
import spaces
import torch
from duckduckgo_search import DDGS
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# ----------------------- Setup & Dependency Installation ----------------------- #
def _run_setup_command(cmd):
    """Run *cmd* (an argv list) and report whether it succeeded.

    Setup is best effort: failures are printed as warnings instead of being
    raised, so the app can still start with reduced functionality.

    Returns:
        True if the command exited with status 0; False if it exited non-zero
        or the executable could not be started at all (e.g. not installed).
    """
    try:
        subprocess.run(cmd, check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        print(f"Warning: Initial setup error: {str(e)}")
        return False
    return True


# Fetch the Kokoro TTS repository. As before, the clone is only attempted after
# `git lfs install` succeeds (presumably because the weights are LFS-tracked),
# and it is skipped when the directory already exists.
if _run_setup_command(['git', 'lfs', 'install']) and not os.path.exists('Kokoro-82M'):
    _run_setup_command(['git', 'clone', 'https://huggingface.co/hexgrad/Kokoro-82M'])

# Install espeak (falling back to espeak-ng), which the TTS path depends on.
# This no longer depends on the git step above succeeding.
if not (_run_setup_command(['apt-get', 'update'])
        and _run_setup_command(['apt-get', 'install', '-y', 'espeak'])):
    print("Warning: Could not install espeak. Trying espeak-ng...")
    if not _run_setup_command(['apt-get', 'install', '-y', 'espeak-ng']):
        print("Warning: Could not install espeak or espeak-ng. TTS functionality may be limited.")
# ----------------------- Global Variables ----------------------- #
# Dropdown label -> voice id. The id is the file stem loaded from
# Kokoro-82M/voices/<id>.pt when synthesizing speech.
VOICE_CHOICES = dict([
    ('🇺🇸 Female (Default)', 'af'),
    ('🇺🇸 Bella', 'af_bella'),
    ('🇺🇸 Sarah', 'af_sarah'),
    ('🇺🇸 Nicole', 'af_nicole'),
])

# Pessimistic default: TTS stays disabled unless the Kokoro TTS modules are
# imported successfully later during startup.
TTS_ENABLED = False
# ----------------------- Model and Tokenizer Initialization ----------------------- #
# Hugging Face model id of the LLM that writes the answers.
model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"

# The tokenizer is loaded once, at import time. Padding reuses the EOS token
# (presumably because the tokenizer defines no pad token of its own).
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token


def init_models():
    """Load and return the causal language model named by ``model_name``.

    Weights are loaded in float16 with ``device_map="auto"``; ``offload_folder``
    gives transformers a place to offload weights that do not fit on the
    available devices.

    NOTE(review): this reloads the full model on every call, and
    ``generate_answer`` calls it once per request. Consider caching it if the
    GPU runtime preserves state between calls.
    """
    load_options = {
        "device_map": "auto",
        "offload_folder": "offload",
        "low_cpu_mem_usage": True,
        "torch_dtype": torch.float16,
    }
    return AutoModelForCausalLM.from_pretrained(model_name, **load_options)
# ----------------------- Kokoro TTS Initialization ----------------------- #
# Make the cloned Kokoro-82M repository importable and import its TTS helpers.
# The directory goes at the FRONT of sys.path: its modules have generic names
# (`models`, `kokoro`), and when appended last, any installed package with the
# same name would shadow the local clone.
import sys

_KOKORO_DIR = 'Kokoro-82M'
try:
    if _KOKORO_DIR not in sys.path:
        sys.path.insert(0, _KOKORO_DIR)
    from models import build_model
    from kokoro import generate
    TTS_ENABLED = True
except Exception as e:
    # Best effort: the app still answers queries, just without speech output.
    print(f"Warning: Could not initialize Kokoro TTS: {str(e)}")
    TTS_ENABLED = False
# ----------------------- Web Search Functions ----------------------- #
def get_web_results(query, max_results=5):
    """Search DuckDuckGo for *query* and return normalized result dicts.

    Args:
        query: Free-text search query.
        max_results: Maximum number of results requested from DuckDuckGo.

    Returns:
        A list of dicts with keys ``title``, ``snippet``, ``url`` and ``date``.
        A field missing from a result defaults to ``""`` instead of discarding
        every result. Returns ``[]`` if the search itself fails, so the caller
        can still answer without web context.
    """
    try:
        with DDGS() as ddgs:
            results = list(ddgs.text(query, max_results=max_results))
    except Exception as e:
        # Search is best effort, but leave a trace rather than failing silently.
        print(f"Warning: web search failed: {str(e)}")
        return []
    return [
        {
            "title": result.get("title", ""),
            "snippet": result.get("body", ""),
            "url": result.get("href", ""),
            "date": result.get("published", ""),
        }
        for result in results
    ]
def format_prompt(query, context):
    """Build the LLM prompt for *query*, grounded in the web *context*.

    Context entries are numbered ``[1]``, ``[2]``, ... in list order, the same
    numbering ``format_sources`` uses for the Sources panel. The ``[n]``
    citations the prompt asks for therefore refer to sources the user can see.

    Args:
        query: The user's question.
        context: Search results; each dict has a ``snippet`` and optionally a
            ``title`` (an empty title is shown as "Source").

    Returns:
        The prompt text. It ends with ``Answer:`` so the generated text can be
        split on that marker.
    """
    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    context_lines = '\n'.join(
        f'- [{i}] {res.get("title") or "Source"}: {res.get("snippet", "")}'
        for i, res in enumerate(context, 1)
    )
    return f"""You are an intelligent search assistant. Answer the user's query using the provided web context.
Current Time: {current_time}
Important: For election-related queries, please distinguish clearly between different election years and types (presidential vs. non-presidential). Only use information from the provided web context.
Query: {query}
Web Context:
{context_lines}
Provide a detailed answer in markdown format. Include relevant information from sources and cite them using [1], [2], etc. If the query is about elections, clearly specify which year and type of election you're discussing.
Answer:"""
def format_sources(web_results):
    """Render search results as HTML for the Sources panel.

    Results are numbered ``[1]``, ``[2]``, ... in list order. Each entry shows
    the title (linked when the URL is http/https), the date if present, and
    the first 150 characters of the snippet.

    Titles, snippets, dates and URLs come from the web and are untrusted, so
    all of them are HTML-escaped before being placed in markup. Non-http(s)
    URLs (e.g. ``javascript:``) are never turned into links.

    Args:
        web_results: List of dicts with ``title``, ``snippet``, ``url`` and
            ``date`` keys (missing keys are tolerated), or empty/None.

    Returns:
        An HTML string, or a short "No sources available" block when there
        are no results.
    """
    if not web_results:
        return "<div class='no-sources'>No sources available</div>"
    items = []
    for i, res in enumerate(web_results, 1):
        title = html.escape(res.get("title") or "Source")
        # Truncate before escaping so an entity is never cut in half.
        snippet = html.escape((res.get("snippet") or "")[:150])
        url = res.get("url") or ""
        date = res.get("date") or ""
        date_html = f"<span class='source-date'>{html.escape(str(date))}</span>" if date else ""
        if url.startswith(("http://", "https://")):
            title_html = (
                f"<a href='{html.escape(url)}' target='_blank' "
                f"rel='noopener noreferrer' class='source-title'>{title}</a>"
            )
        else:
            title_html = f"<span class='source-title'>{title}</span>"
        items.append(f"""
<div class='source-item'>
<div class='source-number'>[{i}]</div>
<div class='source-content'>
{title_html}
{date_html}
<div class='source-snippet'>{snippet}...</div>
</div>
</div>
""")
    return "<div class='sources-container'>" + "".join(items) + "</div>"
# ----------------------- Answer Generation ----------------------- #
@spaces.GPU(duration=30)
def generate_answer(prompt, max_input_tokens=512, max_new_tokens=256):
    """Generate a sampled completion of *prompt* with the LLM.

    Args:
        prompt: Full prompt text (see ``format_prompt``).
        max_input_tokens: The prompt is truncated to this many tokens.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        Only the newly generated text, without the echoed prompt. Decoding just
        the new tokens means that, even when truncation removes the prompt's
        trailing "Answer:" marker, the caller's split on that marker cannot
        leak the prompt into the displayed answer.
    """
    model = init_models()
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=max_input_tokens,
        return_attention_mask=True,
    ).to(model.device)
    outputs = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=max_new_tokens,
        temperature=0.7,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        # `early_stopping` was dropped: it only applies to beam search.
    )
    # generate() returns prompt + continuation; keep just the continuation.
    prompt_length = inputs.input_ids.shape[-1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
def _clean_text_for_tts(text):
    """Drop markdown heading lines, join the rest with spaces, and remove [, ] and *."""
    spoken_lines = [line for line in text.split('\n') if not line.startswith('#')]
    return ' '.join(spoken_lines).replace('[', '').replace(']', '').replace('*', '')


def _chunk_text_for_tts(text, max_chars=1000):
    """Split *text* on '.' into chunks shorter than about *max_chars* characters.

    Text of at most *max_chars* characters is returned unchanged as a single
    chunk. A single sentence longer than *max_chars* still becomes one
    (oversized) chunk.
    """
    if len(text) <= max_chars:
        return [text]
    sentences = text.split('.')
    # Text ending in '.' produces an empty trailing fragment. Dropping it
    # avoids appending a stray extra '.' (or emitting a chunk that is only '.').
    if sentences and not sentences[-1].strip():
        sentences.pop()
    chunks = []
    current_chunk = ""
    for sentence in sentences:
        if len(current_chunk) + len(sentence) < max_chars:
            current_chunk += sentence + "."
        else:
            if current_chunk:
                chunks.append(current_chunk)
            current_chunk = sentence + "."
    if current_chunk:
        chunks.append(current_chunk)
    return chunks


@spaces.GPU(duration=60)
def generate_speech_with_gpu(text, voice_name='af'):
    """Synthesize *text* with Kokoro and return ``(24000, audio)``, or ``None``.

    Loads the Kokoro model and the ``voice_name`` voicepack onto CUDA, strips
    markdown, splits long text into chunks, synthesizes each non-empty chunk,
    and concatenates the resulting audio. Any error is printed with its
    traceback and reported as ``None``, so the caller can fall back to
    text-only output.
    """
    try:
        device = 'cuda'
        tts_model = build_model('Kokoro-82M/kokoro-v0_19.pth', device)
        voicepack = torch.load(f'Kokoro-82M/voices/{voice_name}.pt', weights_only=True).to(device)
        audio_chunks = []
        for chunk in _chunk_text_for_tts(_clean_text_for_tts(text)):
            chunk = chunk.strip()
            if not chunk:
                continue
            chunk_audio, _ = generate(tts_model, chunk, voicepack, lang='a')
            if isinstance(chunk_audio, torch.Tensor):
                chunk_audio = chunk_audio.cpu().numpy()
            audio_chunks.append(chunk_audio)
        if not audio_chunks:
            return None
        final_audio = audio_chunks[0] if len(audio_chunks) == 1 else np.concatenate(audio_chunks)
        return (24000, final_audio)
    except Exception as e:
        print(f"Error generating speech: {str(e)}")
        import traceback
        traceback.print_exc()
        return None
def process_query(query, history, selected_voice='af'):
    """Run search -> answer -> optional speech for one query, streaming UI updates.

    This is a generator. Each ``yield`` is a dict that maps Gradio output
    components to new values, so the UI shows progress ("Searching...",
    "Generating audio...") before the final result.

    Args:
        query: The user's question from the search box.
        history: Prior turns as a list of ``[user, assistant]`` pairs, or None.
        selected_voice: Voice id passed to speech synthesis.
    """
    history = list(history) if history else []
    # Bound before the try so the error handler can always reference it, even
    # when the failure happens before the sources are rendered.
    sources_html = ""

    # Avoid a full search + GPU generation for an empty submission.
    if not query or not query.strip():
        yield {answer_output: gr.Markdown("*Please enter a question.*")}
        return

    try:
        web_results = get_web_results(query)
        sources_html = format_sources(web_results)
        yield {
            answer_output: gr.Markdown("*Searching & Thinking...*"),
            sources_output: gr.HTML(sources_html),
            search_btn: gr.Button("Searching...", interactive=False),
            chat_history_display: history + [[query, "*Searching...*"]],
            audio_output: None,
        }

        answer = generate_answer(format_prompt(query, web_results))
        # The prompt ends with "Answer:"; keep only what follows the last marker.
        final_answer = answer.split("Answer:")[-1].strip()
        updated_history = history + [[query, final_answer]]

        audio = None
        if TTS_ENABLED:
            try:
                yield {
                    answer_output: gr.Markdown(final_answer),
                    sources_output: gr.HTML(sources_html),
                    search_btn: gr.Button("Generating audio...", interactive=False),
                    chat_history_display: updated_history,
                    audio_output: None,
                }
                audio = generate_speech_with_gpu(final_answer, selected_voice)
            except Exception as e:
                # Speech is optional; still show the text answer.
                print(f"Error in speech generation: {str(e)}")
                audio = None

        yield {
            answer_output: gr.Markdown(final_answer),
            sources_output: gr.HTML(sources_html),
            search_btn: gr.Button("Search", interactive=True),
            chat_history_display: updated_history,
            audio_output: audio if audio is not None else gr.Audio(value=None),
        }
    except Exception as e:
        error_message = str(e)
        if "GPU quota" in error_message:
            error_message = "⚠️ GPU quota exceeded. Please try again later when the daily quota resets."
        yield {
            answer_output: gr.Markdown(f"Error: {error_message}"),
            sources_output: gr.HTML(sources_html),
            search_btn: gr.Button("Search", interactive=True),
            chat_history_display: history + [[query, f"*Error: {error_message}*"]],
            audio_output: None,
        }
# ----------------------- Custom CSS for Improved UI ----------------------- #
# Custom dark-theme stylesheet passed to gr.Blocks(css=css). Selectors such as
# #header, .search-container, .search-box, .results-container, .answer-box,
# .chat-history and .voice-selector target the elem_id / elem_classes set on
# the layout components; the .sources-container / .source-* rules are
# presumably for the HTML produced by format_sources.
css = """
.gradio-container {
max-width: 1200px !important;
background-color: #1e1e1e !important;
padding: 20px;
border-radius: 12px;
}
#header {
text-align: center;
padding: 2rem 0;
background: #272727;
border-radius: 12px;
color: #ffffff;
margin-bottom: 2rem;
}
#header h1 {
font-size: 2.5rem;
margin-bottom: 0.5rem;
}
.search-container {
background: #272727;
border-radius: 12px;
padding: 1.5rem;
margin-bottom: 1rem;
}
.search-box {
padding: 1rem;
background: #333333;
border-radius: 8px;
margin-bottom: 1rem;
}
.search-box input[type="text"] {
background: #444444 !important;
border: 1px solid #555555 !important;
color: #ffffff !important;
border-radius: 8px !important;
}
.search-box input[type="text"]::placeholder {
color: #bbbbbb !important;
}
.search-box button {
background: #2563eb !important;
border: none !important;
}
.results-container {
background: #2c2c2c;
border-radius: 8px;
padding: 1.5rem;
margin-top: 1rem;
}
.answer-box {
background: #3a3a3a;
border-radius: 8px;
padding: 1.5rem;
color: #ffffff;
margin-bottom: 1rem;
}
.answer-box p {
color: #e0e0e0;
line-height: 1.6;
}
.sources-container {
margin-top: 1rem;
background: #2c2c2c;
border-radius: 8px;
padding: 1rem;
}
.source-item {
display: flex;
padding: 12px;
margin: 8px 0;
background: #3a3a3a;
border-radius: 8px;
transition: all 0.2s;
}
.source-item:hover {
background: #4a4a4a;
}
.source-number {
font-weight: bold;
margin-right: 12px;
color: #60a5fa;
}
.source-content {
flex: 1;
}
.source-title {
color: #60a5fa;
font-weight: 500;
text-decoration: none;
display: block;
margin-bottom: 4px;
}
.source-date {
color: #bbbbbb;
font-size: 0.9em;
margin-left: 8px;
}
.source-snippet {
color: #e0e0e0;
font-size: 0.9em;
line-height: 1.4;
}
.chat-history {
max-height: 400px;
overflow-y: auto;
padding: 1rem;
background: #2c2c2c;
border-radius: 8px;
margin-top: 1rem;
}
.voice-selector {
margin-top: 1rem;
background: #333333;
border-radius: 8px;
padding: 0.5rem;
}
.voice-selector select {
background: #444444 !important;
color: #ffffff !important;
border: 1px solid #555555 !important;
}
footer {
text-align: center;
padding: 1rem 0;
font-size: 0.9em;
color: #bbbbbb;
}
"""
# ----------------------- Gradio Interface ----------------------- #
with gr.Blocks(title="AI Search Assistant", css=css) as demo:
    # Header banner (use 'elem_id' rather than 'id' so the #header CSS applies).
    with gr.Column(elem_id="header"):
        gr.Markdown("# 🔍 AI Search Assistant")
        gr.Markdown("### Powered by DeepSeek & Real-time Web Results with Voice")

    # Query box, search button and voice picker.
    with gr.Column(elem_classes="search-container"):
        with gr.Row(elem_classes="search-box"):
            search_input = gr.Textbox(
                label="",
                placeholder="Ask anything...",
                scale=5,
                container=False,
            )
            search_btn = gr.Button("Search", variant="primary", scale=1)
        # Choices are (label, value) pairs, so the handler receives the voice id.
        voice_select = gr.Dropdown(
            choices=list(VOICE_CHOICES.items()),
            value='af',
            label="Select Voice",
            elem_classes="voice-selector",
        )

    # Answer + audio + chat history on the left, sources on the right.
    with gr.Row(elem_classes="results-container"):
        with gr.Column(scale=2):
            with gr.Column(elem_classes="answer-box"):
                answer_output = gr.Markdown()
                audio_output = gr.Audio(label="Voice Response")
            with gr.Accordion("Chat History", open=False, elem_classes="accordion"):
                chat_history_display = gr.Chatbot(elem_classes="chat-history")
        with gr.Column(scale=1):
            with gr.Column():
                gr.Markdown("### Sources")
                sources_output = gr.HTML()

    with gr.Row():
        gr.Examples(
            examples=[
                "musk explores blockchain for doge",
                "nvidia to launch new gaming card",
                "What are the best practices for sustainable living?",
                "How is climate change affecting ocean ecosystems?",
            ],
            inputs=search_input,
            label="Try these examples",
        )

    # The Chatbot's own value is fed back as the history input. A separate
    # gr.State was never written by process_query, so history never grew
    # beyond the latest turn.
    for trigger in (search_btn.click, search_input.submit):
        trigger(
            fn=process_query,
            inputs=[search_input, chat_history_display, voice_select],
            outputs=[answer_output, sources_output, search_btn, chat_history_display, audio_output],
        )

if __name__ == "__main__":
    demo.launch(share=True)