# app.py — Pixeltable + Mistral AI prompt-engineering studio (Gradio UI)
# Standard library
import getpass
import os
import re
from collections import Counter
from datetime import datetime

# Third-party
import gradio as gr
import nltk
import pixeltable as pxt
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from pixeltable.functions.mistralai import chat_completions
from textblob import TextBlob
# Ensure necessary NLTK data is downloaded
nltk.download('punkt', quiet=True)
nltk.download('stopwords', quiet=True)
nltk.download('punkt_tab', quiet=True)
# Set up Mistral API key
if 'MISTRAL_API_KEY' not in os.environ:
os.environ['MISTRAL_API_KEY'] = getpass.getpass('Mistral AI API Key:')
# Define UDFs
@pxt.udf
def get_sentiment_score(text: str) -> float:
return TextBlob(text).sentiment.polarity
@pxt.udf
def extract_keywords(text: str, num_keywords: int = 5) -> list:
stop_words = set(stopwords.words('english'))
words = word_tokenize(text.lower())
keywords = [word for word in words if word.isalnum() and word not in stop_words]
return sorted(set(keywords), key=keywords.count, reverse=True)[:num_keywords]
@pxt.udf
def calculate_readability(text: str) -> float:
words = len(re.findall(r'\w+', text))
sentences = len(re.findall(r'\w+[.!?]', text)) or 1
average_words_per_sentence = words / sentences
return 206.835 - 1.015 * average_words_per_sentence
# Function to run inference and analysis
def run_inference_and_analysis(task, system_prompt, input_text, temperature, top_p, max_tokens, stop, random_seed, safe_prompt):
# Initialize Pixeltable
pxt.drop_table('mistral_prompts', ignore_errors=True)
t = pxt.create_table('mistral_prompts', {
'task': pxt.String,
'system': pxt.String,
'input_text': pxt.String,
'timestamp': pxt.Timestamp,
'temperature': pxt.Float,
'top_p': pxt.Float,
'max_tokens': pxt.Int,
'stop': pxt.String,
'random_seed': pxt.Int,
'safe_prompt': pxt.Bool
})
# Insert new row into Pixeltable
t.insert([{
'task': task,
'system': system_prompt,
'input_text': input_text,
'timestamp': datetime.now(),
'temperature': temperature,
'top_p': top_p,
'max_tokens': max_tokens,
'stop': stop,
'random_seed': random_seed,
'safe_prompt': safe_prompt
}])
# Define messages for chat completion
msgs = [
{'role': 'system', 'content': t.system},
{'role': 'user', 'content': t.input_text}
]
common_params = {
'messages': msgs,
'temperature': temperature,
'top_p': top_p,
'max_tokens': max_tokens if max_tokens is not None else 300,
'stop': stop.split(',') if stop else None,
'random_seed': random_seed,
'safe_prompt': safe_prompt
}
# Add computed columns for model responses and analysis
t.add_computed_column(open_mistral_nemo=chat_completions(model='open-mistral-nemo', **common_params))
t.add_computed_column(mistral_medium=chat_completions(model='mistral-medium', **common_params))
# Extract responses
t.add_computed_column(omn_response=t.open_mistral_nemo.choices[0].message.content.astype(pxt.String))
t.add_computed_column(ml_response=t.mistral_medium.choices[0].message.content.astype(pxt.String))
# Add computed columns for analysis
t.add_computed_column(large_sentiment_score=get_sentiment_score(t.ml_response))
t.add_computed_column(large_keywords=extract_keywords(t.ml_response))
t.add_computed_column(large_readability_score=calculate_readability(t.ml_response))
t.add_computed_column(open_sentiment_score=get_sentiment_score(t.omn_response))
t.add_computed_column(open_keywords=extract_keywords(t.omn_response))
t.add_computed_column(open_readability_score=calculate_readability(t.omn_response))
# Retrieve results
results = t.select(
t.omn_response, t.ml_response,
t.large_sentiment_score, t.open_sentiment_score,
t.large_keywords, t.open_keywords,
t.large_readability_score, t.open_readability_score
).tail(1)
history = t.select(t.timestamp, t.task, t.system, t.input_text).order_by(t.timestamp, asc=False).collect().to_pandas()
responses = t.select(t.timestamp, t.omn_response, t.ml_response).order_by(t.timestamp, asc=False).collect().to_pandas()
analysis = t.select(
t.timestamp,
t.open_sentiment_score,
t.large_sentiment_score,
t.open_keywords,
t.large_keywords,
t.open_readability_score,
t.large_readability_score
).order_by(t.timestamp, asc=False).collect().to_pandas()
params = t.select(
t.timestamp,
t.temperature,
t.top_p,
t.max_tokens,
t.stop,
t.random_seed,
t.safe_prompt
).order_by(t.timestamp, asc=False).collect().to_pandas()
return (
results['omn_response'][0],
results['ml_response'][0],
results['large_sentiment_score'][0],
results['open_sentiment_score'][0],
results['large_keywords'][0],
results['open_keywords'][0],
results['large_readability_score'][0],
results['open_readability_score'][0],
history,
responses,
analysis,
params
)
def gradio_interface():
with gr.Blocks(theme=gr.themes.Base(), title="Pixeltable LLM Studio") as demo:
# Enhanced Header with Branding
gr.HTML("""
<div style="text-align: center; padding: 20px; background: linear-gradient(to right, #4F46E5, #7C3AED);" class="shadow-lg">
<img src="https://raw.githubusercontent.com/pixeltable/pixeltable/main/docs/source/data/pixeltable-logo-large.png"
alt="Pixeltable" style="max-width: 200px; margin-bottom: 15px;" />
<h1 style="color: white; font-size: 2.5rem; margin-bottom: 10px;">LLM Studio</h1>
<p style="color: #E5E7EB; font-size: 1.1rem;">
Powered by Pixeltable's Unified AI Data Infrastructure
</p>
</div>
""")
# Product Overview Cards
with gr.Row():
with gr.Column():
gr.HTML("""
<div style="padding: 20px; background-color: white; border-radius: 10px; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1); margin: 10px;">
<h3 style="color: #4F46E5; margin-bottom: 10px;">πŸš€ Why Pixeltable?</h3>
<ul style="list-style-type: none; padding-left: 0;">
<li style="margin-bottom: 8px;">✨ Unified data management for AI workflows</li>
<li style="margin-bottom: 8px;">πŸ“Š Automatic versioning and lineage tracking</li>
<li style="margin-bottom: 8px;">⚑ Seamless model integration and deployment</li>
<li style="margin-bottom: 8px;">πŸ” Advanced querying and analysis capabilities</li>
</ul>
</div>
""")
with gr.Column():
gr.HTML("""
<div style="padding: 20px; background-color: white; border-radius: 10px; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1); margin: 10px;">
<h3 style="color: #4F46E5; margin-bottom: 10px;">πŸ’‘ Features</h3>
<ul style="list-style-type: none; padding-left: 0;">
<li style="margin-bottom: 8px;">πŸ”„ Compare multiple LLM models side-by-side</li>
<li style="margin-bottom: 8px;">πŸ“ˆ Track and analyze model performance</li>
<li style="margin-bottom: 8px;">🎯 Experiment with different prompts and parameters</li>
<li style="margin-bottom: 8px;">πŸ“ Automatic analysis with sentiment and readability scores</li>
</ul>
</div>
""")
# Main Interface
with gr.Tabs() as tabs:
with gr.TabItem("🎯 Experiment", id=0):
with gr.Row():
with gr.Column(scale=1):
gr.HTML("""
<div style="padding: 15px; background-color: #F3F4F6; border-radius: 8px; margin-bottom: 15px;">
<h3 style="color: #4F46E5; margin-bottom: 10px;">Experiment Setup</h3>
<p style="color: #6B7280; font-size: 0.9rem;">Configure your prompt engineering experiment below</p>
</div>
""")
task = gr.Textbox(
label="Task Category",
placeholder="e.g., Sentiment Analysis, Text Generation, Summarization",
elem_classes="input-style"
)
system_prompt = gr.Textbox(
label="System Prompt",
placeholder="Define the AI's role and task...",
lines=3,
elem_classes="input-style"
)
input_text = gr.Textbox(
label="Input Text",
placeholder="Enter your prompt or text to analyze...",
lines=4,
elem_classes="input-style"
)
with gr.Accordion("πŸ› οΈ Advanced Settings", open=False):
temperature = gr.Slider(minimum=0, maximum=1, value=0.7, step=0.1, label="Temperature")
top_p = gr.Slider(minimum=0, maximum=1, value=0.9, step=0.1, label="Top P")
max_tokens = gr.Number(label="Max Tokens", value=300)
stop = gr.Textbox(label="Stop Sequences (comma-separated)")
random_seed = gr.Number(label="Random Seed", value=None)
safe_prompt = gr.Checkbox(label="Safe Prompt", value=False)
submit_btn = gr.Button(
"πŸš€ Run Analysis",
variant="primary",
scale=1,
min_width=200
)
with gr.Column(scale=1):
gr.HTML("""
<div style="padding: 15px; background-color: #F3F4F6; border-radius: 8px; margin-bottom: 15px;">
<h3 style="color: #4F46E5; margin-bottom: 10px;">Results</h3>
<p style="color: #6B7280; font-size: 0.9rem;">Compare model outputs and analysis metrics</p>
</div>
""")
with gr.Group():
omn_response = gr.Textbox(
label="Open-Mistral-Nemo Response",
elem_classes="output-style"
)
ml_response = gr.Textbox(
label="Mistral-Medium Response",
elem_classes="output-style"
)
with gr.Group():
with gr.Row():
with gr.Column():
gr.HTML("<h4>πŸ“Š Sentiment Analysis</h4>")
large_sentiment = gr.Number(label="Mistral-Medium")
open_sentiment = gr.Number(label="Open-Mistral-Nemo")
with gr.Column():
gr.HTML("<h4>πŸ“ˆ Readability Scores</h4>")
large_readability = gr.Number(label="Mistral-Medium")
open_readability = gr.Number(label="Open-Mistral-Nemo")
gr.HTML("<h4>πŸ”‘ Key Terms</h4>")
with gr.Row():
large_keywords = gr.Textbox(label="Mistral-Medium Keywords")
open_keywords = gr.Textbox(label="Open-Mistral-Nemo Keywords")
with gr.TabItem("πŸ“Š History & Analysis", id=1):
with gr.Tabs():
with gr.TabItem("Prompt History"):
history = gr.DataFrame(
headers=["Timestamp", "Task", "System Prompt", "Input Text"],
wrap=True,
elem_classes="table-style"
)
with gr.TabItem("Model Responses"):
responses = gr.DataFrame(
headers=["Timestamp", "Open-Mistral-Nemo", "Mistral-Medium"],
wrap=True,
elem_classes="table-style"
)
with gr.TabItem("Analysis Results"):
analysis = gr.DataFrame(
headers=[
"Timestamp",
"Open-Mistral-Nemo Sentiment",
"Mistral-Medium Sentiment",
"Open-Mistral-Nemo Keywords",
"Mistral-Medium Keywords",
"Open-Mistral-Nemo Readability",
"Mistral-Medium Readability"
],
wrap=True,
elem_classes="table-style"
)
# Footer with links and additional info
gr.HTML("""
<div style="text-align: center; padding: 20px; margin-top: 30px; border-top: 1px solid #E5E7EB;">
<div style="margin-bottom: 20px;">
<h3 style="color: #4F46E5;">Built with Pixeltable</h3>
<p style="color: #6B7280;">The unified data infrastructure for AI applications</p>
</div>
<div style="display: flex; justify-content: center; gap: 20px;">
<a href="https://github.com/pixeltable/pixeltable" target="_blank"
style="color: #4F46E5; text-decoration: none;">
πŸ“š Documentation
</a>
<a href="https://github.com/pixeltable/pixeltable" target="_blank"
style="color: #4F46E5; text-decoration: none;">
πŸ’» GitHub
</a>
<a href="https://join.slack.com/t/pixeltablecommunity/shared_invite/zt-21fybjbn2-fZC_SJiuG6QL~Ai8T6VpFQ" target="_blank"
style="color: #4F46E5; text-decoration: none;">
πŸ’¬ Community
</a>
</div>
</div>
""")
# Custom CSS
gr.HTML("""
<style>
.input-style {
border: 1px solid #E5E7EB !important;
border-radius: 8px !important;
padding: 12px !important;
}
.output-style {
background-color: #F9FAFB !important;
border-radius: 8px !important;
padding: 12px !important;
}
.table-style {
border-collapse: collapse !important;
width: 100% !important;
}
.table-style th {
background-color: #F3F4F6 !important;
padding: 12px !important;
}
</style>
""")
# Setup event handlers
submit_btn.click(
run_inference_and_analysis,
inputs=[task, system_prompt, input_text, temperature, top_p, max_tokens, stop, random_seed, safe_prompt],
outputs=[omn_response, ml_response, large_sentiment, open_sentiment, large_keywords, open_keywords,
large_readability, open_readability, history, responses, analysis, params]
)
return demo
# Launch the Gradio interface
if __name__ == "__main__":
gradio_interface().launch()