RAGOndevice / app.py
openfree's picture
Update app.py
55caecd verified
import os
# 1) Dynamo ์™„์ „ ๋น„ํ™œ์„ฑํ™”
os.environ["TORCH_DYNAMO_DISABLE"] = "1"
# 2) Triton์˜ cudagraphs ์ตœ์ ํ™” ๋น„ํ™œ์„ฑํ™”
os.environ["TRITON_DISABLE_CUDAGRAPHS"] = "1"
# (์˜ต์…˜) ๊ฒฝ๊ณ  ๋ฌด์‹œ ์„ค์ •
import warnings
warnings.filterwarnings("ignore", message="skipping cudagraphs due to mutated inputs")
warnings.filterwarnings("ignore", message="Not enough SMs to use max_autotune_gemm mode")
import torch
# TensorFloat32 ์—ฐ์‚ฐ ํ™œ์„ฑํ™” (์„ฑ๋Šฅ ์ตœ์ ํ™”)
torch.set_float32_matmul_precision('high')
import torch._inductor
torch._inductor.config.triton.cudagraphs = False
import torch._dynamo
# suppress_errors (์˜ค๋ฅ˜ ์‹œ eager๋กœ fallback)
torch._dynamo.config.suppress_errors = True
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
from datasets import load_dataset
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
import pandas as pd
import json
from datetime import datetime
import pyarrow.parquet as pq
import pypdf
import io
import platform
import subprocess
import pytesseract
from pdf2image import convert_from_path
import queue
import time
# -------------------- PDF to Markdown ๋ณ€ํ™˜ ๊ด€๋ จ import --------------------
try:
import re
import requests
from bs4 import BeautifulSoup
import urllib.request
import ocrmypdf
import pytz
import urllib.parse
from pypdf import PdfReader
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
"ํ•„์ˆ˜ ๋ชจ๋“ˆ์ด ๋ˆ„๋ฝ๋˜์—ˆ์Šต๋‹ˆ๋‹ค. 'beautifulsoup4' ํŒจํ‚ค์ง€๋ฅผ ์„ค์น˜ํ•ด์ฃผ์„ธ์š”.\n"
"์˜ˆ: pip install beautifulsoup4"
)
# ---------------------------------------------------------------------------
# ์ „์—ญ ๋ณ€์ˆ˜
current_file_context = None
# ํ™˜๊ฒฝ ๋ณ€์ˆ˜ ์„ค์ •
HF_TOKEN = os.environ.get("HF_TOKEN", None)
MODEL_ID = "CohereForAI/c4ai-command-r7b-12-2024"
MODEL_NAME = MODEL_ID.split("/")[-1]
model = None # ์ „์—ญ์—์„œ ๊ด€๋ฆฌ
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# (1) ์œ„ํ‚คํ”ผ๋””์•„ ๋ฐ์ดํ„ฐ์…‹ ๋กœ๋“œ
wiki_dataset = load_dataset("lcw99/wikipedia-korean-20240501-1million-qna")
print("Wikipedia dataset loaded:", wiki_dataset)
# (2) TF-IDF ๋ฒกํ„ฐ๋ผ์ด์ € ์ดˆ๊ธฐํ™” ๋ฐ ํ•™์Šต (์ผ๋ถ€๋งŒ ์‚ฌ์šฉ)
print("TF-IDF ๋ฒกํ„ฐํ™” ์‹œ์ž‘...")
questions = wiki_dataset['train']['question'][:10000]
vectorizer = TfidfVectorizer(max_features=1000)
question_vectors = vectorizer.fit_transform(questions)
print("TF-IDF ๋ฒกํ„ฐํ™” ์™„๋ฃŒ")
# ------------------------- ChatHistory ํด๋ž˜์Šค -------------------------
class ChatHistory:
def __init__(self):
self.history = []
self.history_file = "/tmp/chat_history.json"
self.load_history()
def add_conversation(self, user_msg: str, assistant_msg: str):
conversation = {
"timestamp": datetime.now().isoformat(),
"messages": [
{"role": "user", "content": user_msg},
{"role": "assistant", "content": assistant_msg}
]
}
self.history.append(conversation)
self.save_history()
def format_for_display(self):
formatted = []
for conv in self.history:
formatted.append([
conv["messages"][0]["content"],
conv["messages"][1]["content"]
])
return formatted
def get_messages_for_api(self):
messages = []
for conv in self.history:
messages.extend([
{"role": "user", "content": conv["messages"][0]["content"]},
{"role": "assistant", "content": conv["messages"][1]["content"]}
])
return messages
def clear_history(self):
self.history = []
self.save_history()
def save_history(self):
try:
with open(self.history_file, 'w', encoding='utf-8') as f:
json.dump(self.history, f, ensure_ascii=False, indent=2)
except Exception as e:
print(f"ํžˆ์Šคํ† ๋ฆฌ ์ €์žฅ ์‹คํŒจ: {e}")
def load_history(self):
try:
if os.path.exists(self.history_file):
with open(self.history_file, 'r', encoding='utf-8') as f:
self.history = json.load(f)
except Exception as e:
print(f"ํžˆ์Šคํ† ๋ฆฌ ๋กœ๋“œ ์‹คํŒจ: {e}")
self.history = []
chat_history = ChatHistory()
# ------------------------- ์œ„ํ‚ค ๋ฌธ์„œ ๊ฒ€์ƒ‰ (TF-IDF) -------------------------
def find_relevant_context(query, top_k=3):
query_vector = vectorizer.transform([query])
similarities = (query_vector * question_vectors.T).toarray()[0]
top_indices = np.argsort(similarities)[-top_k:][::-1]
relevant_contexts = []
for idx in top_indices:
if similarities[idx] > 0:
relevant_contexts.append({
'question': questions[idx],
'answer': wiki_dataset['train']['answer'][idx],
'similarity': similarities[idx]
})
return relevant_contexts
def init_msg():
return "ํŒŒ์ผ์„ ๋ถ„์„ํ•˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค..."
# -------------------- PDF ํŒŒ์ผ์„ Markdown์œผ๋กœ ๋ณ€ํ™˜ํ•˜๋Š” ์œ ํ‹ธ ํ•จ์ˆ˜๋“ค --------------------
def extract_text_from_pdf(reader: PdfReader) -> str:
full_text = ""
for idx, page in enumerate(reader.pages):
text = page.extract_text() or ""
if len(text) > 0:
full_text += f"---- Page {idx+1} ----\n" + text + "\n\n"
return full_text.strip()
def convert_pdf_to_markdown(pdf_file: str):
try:
reader = PdfReader(pdf_file)
except Exception as e:
return f"PDF ํŒŒ์ผ์„ ์ฝ๋Š” ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}", None, None
raw_meta = reader.metadata
metadata = {
"author": raw_meta.author if raw_meta else None,
"creator": raw_meta.creator if raw_meta else None,
"producer": raw_meta.producer if raw_meta else None,
"subject": raw_meta.subject if raw_meta else None,
"title": raw_meta.title if raw_meta else None,
}
full_text = extract_text_from_pdf(reader)
image_count = sum(len(page.images) for page in reader.pages)
if image_count > 0 and len(full_text) < 1000:
try:
out_pdf_file = pdf_file.replace(".pdf", "_ocr.pdf")
ocrmypdf.ocr(pdf_file, out_pdf_file, force_ocr=True)
reader_ocr = PdfReader(out_pdf_file)
full_text = extract_text_from_pdf(reader_ocr)
except Exception as e:
full_text = f"OCR ์ฒ˜๋ฆฌ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}\n\n์›๋ณธ PDF ํ…์ŠคํŠธ:\n\n" + full_text
return full_text, metadata, pdf_file
# ------------------------- ํŒŒ์ผ ๋ถ„์„ ํ•จ์ˆ˜ -------------------------
def analyze_file_content(content, file_type):
if file_type in ['parquet', 'csv']:
try:
lines = content.split('\n')
header = lines[0]
columns = header.count('|') - 1
rows = len(lines) - 3
return f"๐Ÿ“Š Dataset Structure: {columns} columns, {rows} rows"
except:
return "โŒ Failed to analyze dataset structure"
lines = content.split('\n')
total_lines = len(lines)
non_empty_lines = len([line for line in lines if line.strip()])
if any(keyword in content.lower() for keyword in ['def ', 'class ', 'import ', 'function']):
functions = len([line for line in lines if 'def ' in line])
classes = len([line for line in lines if 'class ' in line])
imports = len([line for line in lines if 'import ' in line or 'from ' in line])
return f"๐Ÿ’ป Code Structure: {total_lines} lines (Functions: {functions}, Classes: {classes}, Imports: {imports})"
paragraphs = content.count('\n\n') + 1
words = len(content.split())
return f"๐Ÿ“ Document Structure: {total_lines} lines, {paragraphs} paragraphs, approximately {words} words"
def read_uploaded_file(file):
if file is None:
return "", ""
import pyarrow.parquet as pq
import pandas as pd
from tabulate import tabulate
try:
file_ext = os.path.splitext(file.name)[1].lower()
if file_ext == '.parquet':
try:
table = pq.read_table(file.name)
df = table.to_pandas()
content = f"๐Ÿ“Š Parquet File Analysis:\n\n"
content += f"1. Basic Information:\n"
content += f"- Total Rows: {len(df):,}\n"
content += f"- Total Columns: {len(df.columns)}\n"
mem_usage = df.memory_usage(deep=True).sum() / 1024 / 1024
content += f"- Memory Usage: {mem_usage:.2f} MB\n\n"
content += f"2. Column Information:\n"
for col in df.columns:
content += f"- {col} ({df[col].dtype})\n"
content += f"\n3. Data Preview:\n"
content += tabulate(df.head(5), headers='keys', tablefmt='pipe', showindex=False)
content += f"\n\n4. Missing Values:\n"
null_counts = df.isnull().sum()
for col, count in null_counts[null_counts > 0].items():
rate = count / len(df) * 100
content += f"- {col}: {count:,} ({rate:.1f}%)\n"
numeric_cols = df.select_dtypes(include=['int64', 'float64']).columns
if len(numeric_cols) > 0:
content += f"\n5. Numeric Column Statistics:\n"
stats_df = df[numeric_cols].describe()
content += tabulate(stats_df, headers='keys', tablefmt='pipe')
return content, "parquet"
except Exception as e:
return f"Error reading Parquet file: {str(e)}", "error"
elif file_ext == '.pdf':
try:
markdown_text, metadata, processed_pdf_path = convert_pdf_to_markdown(file.name)
if metadata is None:
return f"PDF ํŒŒ์ผ ๋ณ€ํ™˜ ์˜ค๋ฅ˜ ๋˜๋Š” ์ฝ๊ธฐ ์‹คํŒจ.\n\n์›๋ณธ ๋ฉ”์‹œ์ง€:\n{markdown_text}", "error"
content = "# PDF to Markdown Conversion\n\n"
content += "## Metadata\n"
for k, v in metadata.items():
content += f"**{k.capitalize()}**: {v}\n\n"
content += "## Extracted Text\n\n"
content += markdown_text
return content, "pdf"
except Exception as e:
return f"Error reading PDF file: {str(e)}", "error"
elif file_ext == '.csv':
encodings = ['utf-8', 'cp949', 'euc-kr', 'latin1']
for encoding in encodings:
try:
df = pd.read_csv(file.name, encoding=encoding)
content = f"๐Ÿ“Š CSV File Analysis:\n\n"
content += f"1. Basic Information:\n"
content += f"- Total Rows: {len(df):,}\n"
content += f"- Total Columns: {len(df.columns)}\n"
mem_usage = df.memory_usage(deep=True).sum() / 1024 / 1024
content += f"- Memory Usage: {mem_usage:.2f} MB\n\n"
content += f"2. Column Information:\n"
for col in df.columns:
content += f"- {col} ({df[col].dtype})\n"
content += f"\n3. Data Preview:\n"
content += df.head(5).to_markdown(index=False)
content += f"\n\n4. Missing Values:\n"
null_counts = df.isnull().sum()
for col, count in null_counts[null_counts > 0].items():
rate = count / len(df) * 100
content += f"- {col}: {count:,} ({rate:.1f}%)\n"
return content, "csv"
except UnicodeDecodeError:
continue
raise UnicodeDecodeError(
f"Unable to read file with supported encodings ({', '.join(encodings)})"
)
else:
encodings = ['utf-8', 'cp949', 'euc-kr', 'latin1']
for encoding in encodings:
try:
with open(file.name, 'r', encoding=encoding) as f:
content = f.read()
lines = content.split('\n')
total_lines = len(lines)
non_empty_lines = len([line for line in lines if line.strip()])
is_code = any(
keyword in content.lower()
for keyword in ['def ', 'class ', 'import ', 'function']
)
analysis = "\n๐Ÿ“ File Analysis:\n"
if is_code:
functions = sum('def ' in line for line in lines)
classes = sum('class ' in line for line in lines)
imports = sum(
('import ' in line) or ('from ' in line)
for line in lines
)
analysis += f"- File Type: Code\n"
analysis += f"- Total Lines: {total_lines:,}\n"
analysis += f"- Functions: {functions}\n"
analysis += f"- Classes: {classes}\n"
analysis += f"- Import Statements: {imports}\n"
else:
words = len(content.split())
chars = len(content)
analysis += f"- File Type: Text\n"
analysis += f"- Total Lines: {total_lines:,}\n"
analysis += f"- Non-empty Lines: {non_empty_lines:,}\n"
analysis += f"- Word Count: {words:,}\n"
analysis += f"- Character Count: {chars:,}\n"
return content + analysis, "text"
except UnicodeDecodeError:
continue
raise UnicodeDecodeError(
f"Unable to read file with supported encodings ({', '.join(encodings)})"
)
except Exception as e:
return f"Error reading file: {str(e)}", "error"
# ------------------------- CSS -------------------------
CSS = """
/* (์ƒ๋žต: ๋™์ผ) */
"""
def clear_cuda_memory():
if hasattr(torch.cuda, 'empty_cache'):
with torch.cuda.device('cuda'):
torch.cuda.empty_cache()
# ------------------------- ๋ชจ๋ธ ๋กœ๋”ฉ ํ•จ์ˆ˜ -------------------------
@spaces.GPU
def load_model():
try:
clear_cuda_memory()
loaded_model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype=torch.bfloat16,
device_map="auto",
low_cpu_mem_usage=True,
)
# (์ค‘์š”) ๋ชจ๋ธ ๊ธฐ๋ณธ config์—์„œ๋„ ์บ์‹œ ์‚ฌ์šฉ ๊บผ๋‘˜ ์ˆ˜ ์žˆ์Œ
loaded_model.config.use_cache = False
return loaded_model
except Exception as e:
print(f"๋ชจ๋ธ ๋กœ๋“œ ์˜ค๋ฅ˜: {str(e)}")
raise
def build_prompt(conversation: list) -> str:
prompt = ""
for msg in conversation:
if msg["role"] == "user":
prompt += "User: " + msg["content"] + "\n"
elif msg["role"] == "assistant":
prompt += "Assistant: " + msg["content"] + "\n"
prompt += "Assistant: "
return prompt
# ------------------------- ๋ฉ”์‹œ์ง€ ์ŠคํŠธ๋ฆฌ๋ฐ ํ•จ์ˆ˜ -------------------------
@spaces.GPU
def stream_chat(
message: str,
history: list,
uploaded_file,
temperature: float,
max_new_tokens: int,
top_p: float,
top_k: int,
penalty: float
):
global model, current_file_context
try:
if model is None:
model = load_model()
print(f'[User input] message: {message}')
print(f'[History] {history}')
# 1) ํŒŒ์ผ ์—…๋กœ๋“œ ์ฒ˜๋ฆฌ
file_context = ""
if uploaded_file and message == "ํŒŒ์ผ์„ ๋ถ„์„ํ•˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค...":
current_file_context = None
try:
content, file_type = read_uploaded_file(uploaded_file)
if content:
file_analysis = analyze_file_content(content, file_type)
file_context = (
f"\n\n๐Ÿ“„ ํŒŒ์ผ ๋ถ„์„ ๊ฒฐ๊ณผ:\n{file_analysis}"
f"\n\nํŒŒ์ผ ๋‚ด์šฉ:\n```\n{content}\n```"
)
current_file_context = file_context
message = "์—…๋กœ๋“œ๋œ ํŒŒ์ผ์„ ๋ถ„์„ํ•ด์ฃผ์„ธ์š”."
except Exception as e:
print(f"[ํŒŒ์ผ ๋ถ„์„ ์˜ค๋ฅ˜] {str(e)}")
file_context = f"\n\nโŒ ํŒŒ์ผ ๋ถ„์„ ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {str(e)}"
elif current_file_context:
file_context = current_file_context
# 2) ์œ„ํ‚ค ์ปจํ…์ŠคํŠธ
wiki_context = ""
try:
relevant_contexts = find_relevant_context(message)
if relevant_contexts:
wiki_context = "\n\n๊ด€๋ จ ์œ„ํ‚คํ”ผ๋””์•„ ์ •๋ณด:\n"
for ctx in relevant_contexts:
wiki_context += (
f"Q: {ctx['question']}\n"
f"A: {ctx['answer']}\n"
f"์œ ์‚ฌ๋„: {ctx['similarity']:.3f}\n\n"
)
except Exception as e:
print(f"[์ปจํ…์ŠคํŠธ ๊ฒ€์ƒ‰ ์˜ค๋ฅ˜] {str(e)}")
# 3) ๋Œ€ํ™” ์ด๋ ฅ ์ถ•์†Œ
max_history_length = 10
if len(history) > max_history_length:
history = history[-max_history_length:]
conversation = []
for prompt, answer in history:
conversation.extend([
{"role": "user", "content": prompt},
{"role": "assistant", "content": answer}
])
# 4) ์ตœ์ข… ๋ฉ”์‹œ์ง€
final_message = message
if file_context:
final_message = file_context + "\nํ˜„์žฌ ์งˆ๋ฌธ: " + message
if wiki_context:
final_message = wiki_context + "\nํ˜„์žฌ ์งˆ๋ฌธ: " + message
if file_context and wiki_context:
final_message = file_context + wiki_context + "\nํ˜„์žฌ ์งˆ๋ฌธ: " + message
conversation.append({"role": "user", "content": final_message})
# 5) ํ† ํฐํ™”
input_ids_str = build_prompt(conversation)
max_context = 8192
tokenized_input = tokenizer(input_ids_str, return_tensors="pt")
input_length = tokenized_input["input_ids"].shape[1]
# 6) ์ปจํ…์ŠคํŠธ ์ดˆ๊ณผ ์‹œ ์ž๋ฅด๊ธฐ
if input_length > max_context - max_new_tokens:
print(f"[๊ฒฝ๊ณ ] ์ž…๋ ฅ์ด ๋„ˆ๋ฌด ๊น๋‹ˆ๋‹ค: {input_length} ํ† ํฐ -> ์ž˜๋ผ๋ƒ„.")
min_generation = min(256, max_new_tokens)
new_desired_input_length = max_context - min_generation
tokens = tokenizer.encode(input_ids_str)
if len(tokens) > new_desired_input_length:
tokens = tokens[-new_desired_input_length:]
input_ids_str = tokenizer.decode(tokens)
tokenized_input = tokenizer(input_ids_str, return_tensors="pt")
input_length = tokenized_input["input_ids"].shape[1]
print(f"[ํ† ํฐ ๊ธธ์ด] {input_length}")
inputs = tokenized_input.to("cuda")
# 7) ๋‚จ์€ ํ† ํฐ ์ˆ˜๋กœ max_new_tokens ๋ณด์ •
remaining = max_context - input_length
if remaining < max_new_tokens:
print(f"[max_new_tokens ์กฐ์ •] {max_new_tokens} -> {remaining}")
max_new_tokens = remaining
# 8) TextIteratorStreamer ์„ค์ •
streamer = TextIteratorStreamer(
tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True
)
# โ˜… use_cache=False ์„ค์ • (์ค‘์š”) โ˜…
generate_kwargs = dict(
**inputs,
streamer=streamer,
top_k=top_k,
top_p=top_p,
repetition_penalty=penalty,
max_new_tokens=max_new_tokens,
do_sample=True,
temperature=temperature,
pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id,
use_cache=False, # โ† ์—ฌ๊ธฐ๊ฐ€ ํ•ต์‹ฌ!
)
clear_cuda_memory()
# 9) ๋ณ„๋„ ์Šค๋ ˆ๋“œ๋กœ ๋ชจ๋ธ ํ˜ธ์ถœ
thread = Thread(target=model.generate, kwargs=generate_kwargs)
thread.start()
# 10) ์ŠคํŠธ๋ฆฌ๋ฐ
buffer = ""
partial_message = ""
last_yield_time = time.time()
try:
for new_text in streamer:
buffer += new_text
partial_message += new_text
# ํƒ€์ด๋ฐ or ์ผ์ • ๊ธธ์ด๋งˆ๋‹ค UI ์—…๋ฐ์ดํŠธ
current_time = time.time()
if (current_time - last_yield_time > 0.1) or (len(partial_message) > 20):
yield "", history + [[message, buffer]]
partial_message = ""
last_yield_time = current_time
# ๋งˆ์ง€๋ง‰ ์ถœ๋ ฅ
if buffer:
yield "", history + [[message, buffer]]
# ๋Œ€ํ™” ํžˆ์Šคํ† ๋ฆฌ ์ €์žฅ
chat_history.add_conversation(message, buffer)
except Exception as e:
print(f"[์ŠคํŠธ๋ฆฌ๋ฐ ์ค‘ ์˜ค๋ฅ˜] {str(e)}")
if not buffer:
buffer = f"์‘๋‹ต ์ƒ์„ฑ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}"
yield "", history + [[message, buffer]]
if thread.is_alive():
thread.join(timeout=5.0)
clear_cuda_memory()
except Exception as e:
import traceback
error_details = traceback.format_exc()
error_message = f"์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {str(e)}\n{error_details}"
print(f"[Stream chat ์˜ค๋ฅ˜] {error_message}")
clear_cuda_memory()
yield "", history + [[message, error_message]]
# ------------------------- Gradio UI ๊ตฌ์„ฑ -------------------------
def create_demo():
with gr.Blocks(css=CSS) as demo:
with gr.Column(elem_classes="markdown-style"):
gr.Markdown("""
# ๐Ÿค– RAGOndevice
#### ๐Ÿ“Š RAG: Upload and Analyze Files (TXT, CSV, PDF, Parquet files)
Upload your files for data analysis and learning
""")
chatbot = gr.Chatbot(
value=[],
height=600,
label="GiniGEN AI Assistant",
elem_classes="chat-container"
)
with gr.Row(elem_classes="input-container"):
with gr.Column(scale=1, min_width=70):
file_upload = gr.File(
type="filepath",
elem_classes="file-upload-icon",
scale=1,
container=True,
interactive=True,
show_label=False
)
with gr.Column(scale=3):
msg = gr.Textbox(
show_label=False,
placeholder="Type your message here... ๐Ÿ’ญ",
container=False,
elem_classes="input-textbox",
scale=1
)
with gr.Column(scale=1, min_width=70):
send = gr.Button(
"Send",
elem_classes="send-button custom-button",
scale=1
)
with gr.Column(scale=1, min_width=70):
clear = gr.Button(
"Clear",
elem_classes="clear-button custom-button",
scale=1
)
# ๊ณ ๊ธ‰ ์„ค์ •
with gr.Accordion("๐ŸŽฎ Advanced Settings", open=False):
with gr.Row():
with gr.Column(scale=1):
temperature = gr.Slider(
minimum=0, maximum=1, step=0.1, value=0.8,
label="Creativity Level ๐ŸŽจ"
)
max_new_tokens = gr.Slider(
minimum=128, maximum=8000, step=1, value=4000,
label="Maximum Token Count ๐Ÿ“"
)
with gr.Column(scale=1):
top_p = gr.Slider(
minimum=0.0, maximum=1.0, step=0.1, value=0.8,
label="Diversity Control ๐ŸŽฏ"
)
top_k = gr.Slider(
minimum=1, maximum=20, step=1, value=20,
label="Selection Range ๐Ÿ“Š"
)
penalty = gr.Slider(
minimum=0.0, maximum=2.0, step=0.1, value=1.0,
label="Repetition Penalty ๐Ÿ”„"
)
# ์˜ˆ์‹œ ์ž…๋ ฅ
gr.Examples(
examples=[
["Please analyze this code and suggest improvements:\ndef fibonacci(n):\n if n <= 1: return n\n return fibonacci(n-1) + fibonacci(n-2)"],
["Please analyze this data and provide insights:\nAnnual Revenue (Million)\n2019: 1200\n2020: 980\n2021: 1450\n2022: 2100\n2023: 1890"],
["Please solve this math problem step by step: 'When a circle's area is twice that of its inscribed square, find the relationship between the circle's radius and the square's side length.'"],
["Please analyze this marketing campaign's ROI and suggest improvements:\nTotal Cost: $50,000\nReach: 1M users\nClick Rate: 2.3%\nConversion Rate: 0.8%\nAverage Purchase: $35"],
],
inputs=msg
)
# ๋Œ€ํ™” ๋‚ด์šฉ ์ดˆ๊ธฐํ™”
def clear_conversation():
global current_file_context
current_file_context = None
return [], None, "Start a new conversation..."
# ๋ฉ”์‹œ์ง€ ์ „์†ก(Submit)
msg.submit(
stream_chat,
inputs=[msg, chatbot, file_upload, temperature, max_new_tokens, top_p, top_k, penalty],
outputs=[msg, chatbot]
)
send.click(
stream_chat,
inputs=[msg, chatbot, file_upload, temperature, max_new_tokens, top_p, top_k, penalty],
outputs=[msg, chatbot]
)
# ํŒŒ์ผ ์—…๋กœ๋“œ ์ด๋ฒคํŠธ
file_upload.change(
fn=lambda: ("์ฒ˜๋ฆฌ ์ค‘...", [["์‹œ์Šคํ…œ", "ํŒŒ์ผ์„ ๋ถ„์„ ์ค‘์ž…๋‹ˆ๋‹ค. ์ž ์‹œ๋งŒ ๊ธฐ๋‹ค๋ ค์ฃผ์„ธ์š”..."]]),
outputs=[msg, chatbot],
queue=False
).then(
fn=init_msg,
outputs=msg,
queue=False
).then(
fn=stream_chat,
inputs=[msg, chatbot, file_upload, temperature, max_new_tokens, top_p, top_k, penalty],
outputs=[msg, chatbot],
queue=True
)
# Clear ๋ฒ„ํŠผ
clear.click(
fn=clear_conversation,
outputs=[chatbot, file_upload, msg],
queue=False
)
return demo
if __name__ == "__main__":
demo = create_demo()
demo.launch()