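# Prompt-engineering studio: a Gradio front end over Pixeltable that sends the same
# prompt to two Mistral models, stores every request and response in a Pixeltable table,
# and derives sentiment, keyword, and readability metrics via computed columns.
# Assumed dependencies (not pinned here): pixeltable, gradio, textblob, nltk, and the
# Mistral client package that pixeltable.functions.mistralai relies on.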
import os
import re
import getpass
from datetime import datetime

import gradio as gr
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from textblob import TextBlob

import pixeltable as pxt
from pixeltable.functions.mistralai import chat_completions

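# Fetch the NLTK data used by word_tokenize and the stopword list; already-installed
# corpora are skipped, and quiet=True suppresses the download progress output.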
nltk.download('punkt', quiet=True)
nltk.download('stopwords', quiet=True)
nltk.download('punkt_tab', quiet=True)

if 'MISTRAL_API_KEY' not in os.environ:
    os.environ['MISTRAL_API_KEY'] = getpass.getpass('Mistral AI API Key:')


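# Pixeltable UDFs: once registered, these can be referenced in computed columns and
# are evaluated automatically for every row.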
@pxt.udf
def get_sentiment_score(text: str) -> float:
    return TextBlob(text).sentiment.polarity


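# Top-N keywords: lowercase, tokenize, drop stopwords and punctuation, then rank the
# unique tokens by how often they occur in the text.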
@pxt.udf
def extract_keywords(text: str, num_keywords: int = 5) -> list:
    stop_words = set(stopwords.words('english'))
    words = word_tokenize(text.lower())
    keywords = [word for word in words if word.isalnum() and word not in stop_words]
    return sorted(set(keywords), key=keywords.count, reverse=True)[:num_keywords]


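# Simplified Flesch Reading Ease: only the words-per-sentence term is used here; the
# standard formula also subtracts 84.6 * (syllables per word), so these scores skew high.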
@pxt.udf
def calculate_readability(text: str) -> float:
    words = len(re.findall(r'\w+', text))
    sentences = len(re.findall(r'\w+[.!?]', text)) or 1
    average_words_per_sentence = words / sentences
    return 206.835 - 1.015 * average_words_per_sentence


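# End-to-end workflow triggered by the UI: (re)create the table, insert the prompt,
# let computed columns call both Mistral models and the analysis UDFs, then pull the
# latest results and the full history back out for Gradio.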
def run_inference_and_analysis(task, system_prompt, input_text, temperature, top_p,
                               max_tokens, stop, random_seed, safe_prompt):
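    # Start from a clean slate on every run; ignore_errors avoids failing on the
    # first run, when the table does not exist yet.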
    pxt.drop_table('mistral_prompts', ignore_errors=True)
    t = pxt.create_table('mistral_prompts', {
        'task': pxt.String,
        'system': pxt.String,
        'input_text': pxt.String,
        'timestamp': pxt.Timestamp,
        'temperature': pxt.Float,
        'top_p': pxt.Float,
        'max_tokens': pxt.Int,
        'stop': pxt.String,
        'random_seed': pxt.Int,
        'safe_prompt': pxt.Bool
    })

    t.insert([{
        'task': task,
        'system': system_prompt,
        'input_text': input_text,
        'timestamp': datetime.now(),
        'temperature': temperature,
        'top_p': top_p,
        'max_tokens': max_tokens,
        'stop': stop,
        'random_seed': random_seed,
        'safe_prompt': safe_prompt
    }])

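    # Column references (t.system, t.input_text), not literal strings: Pixeltable
    # substitutes each row's values when the computed columns are evaluated.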
    msgs = [
        {'role': 'system', 'content': t.system},
        {'role': 'user', 'content': t.input_text}
    ]

    common_params = {
        'messages': msgs,
        'temperature': temperature,
        'top_p': top_p,
        'max_tokens': max_tokens if max_tokens is not None else 300,
        'stop': stop.split(',') if stop else None,
        'random_seed': random_seed,
        'safe_prompt': safe_prompt
    }

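    # Each computed column issues one chat-completion call per row and stores the
    # raw JSON response.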
    t.add_computed_column(open_mistral_nemo=chat_completions(model='open-mistral-nemo', **common_params))
    t.add_computed_column(mistral_medium=chat_completions(model='mistral-medium', **common_params))

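    # Extract the assistant message text from each JSON response.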
    t.add_computed_column(omn_response=t.open_mistral_nemo.choices[0].message.content.astype(pxt.String))
    t.add_computed_column(ml_response=t.mistral_medium.choices[0].message.content.astype(pxt.String))

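    # Local analysis of both responses: 'large_*' columns hold the mistral-medium
    # metrics, 'open_*' the open-mistral-nemo metrics.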
    t.add_computed_column(large_sentiment_score=get_sentiment_score(t.ml_response))
    t.add_computed_column(large_keywords=extract_keywords(t.ml_response))
    t.add_computed_column(large_readability_score=calculate_readability(t.ml_response))
    t.add_computed_column(open_sentiment_score=get_sentiment_score(t.omn_response))
    t.add_computed_column(open_keywords=extract_keywords(t.omn_response))
    t.add_computed_column(open_readability_score=calculate_readability(t.omn_response))

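    # Latest row only, for the live results panel.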
    results = t.select(
        t.omn_response, t.ml_response,
        t.large_sentiment_score, t.open_sentiment_score,
        t.large_keywords, t.open_keywords,
        t.large_readability_score, t.open_readability_score
    ).tail(1)

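    # History views for the dashboard tabs, newest first, converted to pandas for
    # Gradio's DataFrame components.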
    history = t.select(
        t.timestamp, t.task, t.system, t.input_text
    ).order_by(t.timestamp, asc=False).collect().to_pandas()
    responses = t.select(
        t.timestamp, t.omn_response, t.ml_response
    ).order_by(t.timestamp, asc=False).collect().to_pandas()
    analysis = t.select(
        t.timestamp,
        t.open_sentiment_score,
        t.large_sentiment_score,
        t.open_keywords,
        t.large_keywords,
        t.open_readability_score,
        t.large_readability_score
    ).order_by(t.timestamp, asc=False).collect().to_pandas()
    params = t.select(
        t.timestamp,
        t.temperature,
        t.top_p,
        t.max_tokens,
        t.stop,
        t.random_seed,
        t.safe_prompt
    ).order_by(t.timestamp, asc=False).collect().to_pandas()

    return (
        results['omn_response'][0],
        results['ml_response'][0],
        results['large_sentiment_score'][0],
        results['open_sentiment_score'][0],
        results['large_keywords'][0],
        results['open_keywords'][0],
        results['large_readability_score'][0],
        results['open_readability_score'][0],
        history,
        responses,
        analysis,
        params
    )


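# Build the Gradio UI: experiment inputs on the left, model comparison on the right,
# plus history/analysis dashboards backed by the Pixeltable queries above.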
def gradio_interface():
    with gr.Blocks(theme=gr.themes.Base(), title="Pixeltable LLM Studio") as demo:
        gr.HTML("""
            <div style="text-align: center; padding: 20px; background: linear-gradient(to right, #4F46E5, #7C3AED);" class="shadow-lg">
                <img src="https://raw.githubusercontent.com/pixeltable/pixeltable/main/docs/source/data/pixeltable-logo-large.png"
                     alt="Pixeltable" style="max-width: 200px; margin-bottom: 15px;" />
                <h1 style="color: white; font-size: 2.5rem; margin-bottom: 10px;">LLM Studio</h1>
                <p style="color: #E5E7EB; font-size: 1.1rem;">
                    Powered by Pixeltable's Unified AI Data Infrastructure
                </p>
            </div>
        """)

        with gr.Row():
            with gr.Column():
                gr.HTML("""
                    <div style="padding: 20px; background-color: white; border-radius: 10px; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1); margin: 10px;">
                        <h3 style="color: #4F46E5; margin-bottom: 10px;">Why Pixeltable?</h3>
                        <ul style="list-style-type: none; padding-left: 0;">
                            <li style="margin-bottom: 8px;">Unified data management for AI workflows</li>
                            <li style="margin-bottom: 8px;">Automatic versioning and lineage tracking</li>
                            <li style="margin-bottom: 8px;">Seamless model integration and deployment</li>
                            <li style="margin-bottom: 8px;">Advanced querying and analysis capabilities</li>
                        </ul>
                    </div>
                """)

            with gr.Column():
                gr.HTML("""
                    <div style="padding: 20px; background-color: white; border-radius: 10px; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1); margin: 10px;">
                        <h3 style="color: #4F46E5; margin-bottom: 10px;">Features</h3>
                        <ul style="list-style-type: none; padding-left: 0;">
                            <li style="margin-bottom: 8px;">Compare multiple LLM models side-by-side</li>
                            <li style="margin-bottom: 8px;">Track and analyze model performance</li>
                            <li style="margin-bottom: 8px;">Experiment with different prompts and parameters</li>
                            <li style="margin-bottom: 8px;">Automatic analysis with sentiment and readability scores</li>
                        </ul>
                    </div>
                """)

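        # Two-tab workspace: run experiments, then review prompt history, responses,
        # analysis metrics, and parameters.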
        with gr.Tabs() as tabs:
            with gr.TabItem("Experiment", id=0):
                with gr.Row():
                    with gr.Column(scale=1):
                        gr.HTML("""
                            <div style="padding: 15px; background-color: #F3F4F6; border-radius: 8px; margin-bottom: 15px;">
                                <h3 style="color: #4F46E5; margin-bottom: 10px;">Experiment Setup</h3>
                                <p style="color: #6B7280; font-size: 0.9rem;">Configure your prompt engineering experiment below</p>
                            </div>
                        """)

                        task = gr.Textbox(
                            label="Task Category",
                            placeholder="e.g., Sentiment Analysis, Text Generation, Summarization",
                            elem_classes="input-style"
                        )
                        system_prompt = gr.Textbox(
                            label="System Prompt",
                            placeholder="Define the AI's role and task...",
                            lines=3,
                            elem_classes="input-style"
                        )
                        input_text = gr.Textbox(
                            label="Input Text",
                            placeholder="Enter your prompt or text to analyze...",
                            lines=4,
                            elem_classes="input-style"
                        )

                        with gr.Accordion("Advanced Settings", open=False):
                            temperature = gr.Slider(minimum=0, maximum=1, value=0.7, step=0.1, label="Temperature")
                            top_p = gr.Slider(minimum=0, maximum=1, value=0.9, step=0.1, label="Top P")
                            max_tokens = gr.Number(label="Max Tokens", value=300)
                            min_tokens = gr.Number(label="Min Tokens", value=None)
                            stop = gr.Textbox(label="Stop Sequences (comma-separated)")
                            random_seed = gr.Number(label="Random Seed", value=None)
                            safe_prompt = gr.Checkbox(label="Safe Prompt", value=False)

                        gr.HTML("""
                            <div style="padding: 15px; background-color: #F3F4F6; border-radius: 8px; margin: 20px 0;">
                                <h3 style="color: #4F46E5; margin-bottom: 10px;">Example Prompts</h3>
                                <p style="color: #6B7280; font-size: 0.9rem;">Try these pre-configured examples to get started</p>
                            </div>
                        """)

                        examples = [
                            ["Sentiment Analysis",
                             "You are an AI trained to analyze the sentiment of text. Provide a detailed analysis of the emotional tone, highlighting key phrases that indicate sentiment.",
                             "The new restaurant downtown exceeded all my expectations. The food was exquisite, the service impeccable, and the ambiance was perfect for a romantic evening. I can't wait to go back!",
                             0.3, 0.95, 200, None, "", None, False],

                            ["Story Generation",
                             "You are a creative writer. Generate a short, engaging story based on the given prompt. Include vivid descriptions and an unexpected twist.",
                             "In a world where dreams are shared, a young girl discovers she can manipulate other people's dreams.",
                             0.9, 0.8, 500, 300, "The end", None, False]
                        ]

                        # The example rows prefill the input widgets only (min_tokens is
                        # collected here but not forwarded to the inference call); the
                        # result components are created later in the layout, so click
                        # "Run Analysis" to execute a selected example.
                        gr.Examples(
                            examples=examples,
                            inputs=[
                                task, system_prompt, input_text,
                                temperature, top_p, max_tokens,
                                min_tokens, stop, random_seed,
                                safe_prompt
                            ],
                            elem_classes="examples-style"
                        )

                        submit_btn = gr.Button(
                            "Run Analysis",
                            variant="primary",
                            scale=1,
                            min_width=200
                        )

                    with gr.Column(scale=1):
                        gr.HTML("""
                            <div style="padding: 15px; background-color: #F3F4F6; border-radius: 8px; margin-bottom: 15px;">
                                <h3 style="color: #4F46E5; margin-bottom: 10px;">Results</h3>
                                <p style="color: #6B7280; font-size: 0.9rem;">Compare model outputs and analysis metrics</p>
                            </div>
                        """)

                        with gr.Group():
                            omn_response = gr.Textbox(
                                label="Open-Mistral-Nemo Response",
                                elem_classes="output-style"
                            )
                            ml_response = gr.Textbox(
                                label="Mistral-Medium Response",
                                elem_classes="output-style"
                            )

                        with gr.Group():
                            with gr.Row():
                                with gr.Column():
                                    gr.HTML("<h4>Sentiment Analysis</h4>")
                                    large_sentiment = gr.Number(label="Mistral-Medium")
                                    open_sentiment = gr.Number(label="Open-Mistral-Nemo")

                                with gr.Column():
                                    gr.HTML("<h4>Readability Scores</h4>")
                                    large_readability = gr.Number(label="Mistral-Medium")
                                    open_readability = gr.Number(label="Open-Mistral-Nemo")

                        gr.HTML("<h4>Key Terms</h4>")
                        with gr.Row():
                            large_keywords = gr.Textbox(label="Mistral-Medium Keywords")
                            open_keywords = gr.Textbox(label="Open-Mistral-Nemo Keywords")

            with gr.TabItem("History & Analysis", id=1):
                with gr.Tabs():
                    with gr.TabItem("Prompt History"):
                        history = gr.DataFrame(
                            headers=["Timestamp", "Task", "System Prompt", "Input Text"],
                            wrap=True,
                            elem_classes="table-style"
                        )

                    with gr.TabItem("Model Responses"):
                        responses = gr.DataFrame(
                            headers=["Timestamp", "Open-Mistral-Nemo", "Mistral-Medium"],
                            wrap=True,
                            elem_classes="table-style"
                        )

                    with gr.TabItem("Analysis Results"):
                        analysis = gr.DataFrame(
                            headers=[
                                "Timestamp",
                                "Open-Mistral-Nemo Sentiment",
                                "Mistral-Medium Sentiment",
                                "Open-Mistral-Nemo Keywords",
                                "Mistral-Medium Keywords",
                                "Open-Mistral-Nemo Readability",
                                "Mistral-Medium Readability"
                            ],
                            wrap=True,
                            elem_classes="table-style"
                        )

                    with gr.TabItem("Model Parameters"):
                        params = gr.DataFrame(
                            headers=[
                                "Timestamp",
                                "Temperature",
                                "Top P",
                                "Max Tokens",
                                "Stop Sequences",
                                "Random Seed",
                                "Safe Prompt"
                            ],
                            wrap=True,
                            elem_classes="table-style"
                        )

        gr.HTML("""
            <div style="text-align: center; padding: 20px; margin-top: 30px; border-top: 1px solid #E5E7EB;">
                <div style="margin-bottom: 20px;">
                    <h3 style="color: #4F46E5;">Built with Pixeltable</h3>
                    <p style="color: #6B7280;">The unified data infrastructure for AI applications</p>
                </div>
                <div style="display: flex; justify-content: center; gap: 20px;">
                    <a href="https://github.com/pixeltable/pixeltable" target="_blank"
                       style="color: #4F46E5; text-decoration: none;">
                        Documentation
                    </a>
                    <a href="https://github.com/pixeltable/pixeltable" target="_blank"
                       style="color: #4F46E5; text-decoration: none;">
                        GitHub
                    </a>
                    <a href="https://join.slack.com/t/pixeltablecommunity/shared_invite/zt-21fybjbn2-fZC_SJiuG6QL~Ai8T6VpFQ" target="_blank"
                       style="color: #4F46E5; text-decoration: none;">
                        Community
                    </a>
                </div>
            </div>
        """)

        gr.HTML("""
            <style>
                .input-style {
                    border: 1px solid #E5E7EB !important;
                    border-radius: 8px !important;
                    padding: 12px !important;
                }
                .output-style {
                    background-color: #F9FAFB !important;
                    border-radius: 8px !important;
                    padding: 12px !important;
                }
                .table-style {
                    border-collapse: collapse !important;
                    width: 100% !important;
                }
                .table-style th {
                    background-color: #F3F4F6 !important;
                    padding: 12px !important;
                }
                .examples-style {
                    margin: 20px 0;
                    padding: 15px;
                    border: 1px solid #E5E7EB;
                    border-radius: 8px;
                    background-color: white;
                }
                .examples-style .example-card {
                    border: 1px solid #E5E7EB;
                    border-radius: 6px;
                    padding: 12px;
                    margin-bottom: 10px;
                    transition: all 0.2s;
                }
                .examples-style .example-card:hover {
                    border-color: #4F46E5;
                    box-shadow: 0 2px 4px rgba(0,0,0,0.1);
                }
            </style>
        """)

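        # Wire the button to the workflow; the outputs list mirrors the tuple
        # returned by run_inference_and_analysis.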
        submit_btn.click(
            run_inference_and_analysis,
            inputs=[
                task, system_prompt, input_text,
                temperature, top_p, max_tokens,
                stop, random_seed, safe_prompt
            ],
            outputs=[
                omn_response, ml_response,
                large_sentiment, open_sentiment,
                large_keywords, open_keywords,
                large_readability, open_readability,
                history, responses, analysis, params
            ]
        )

    return demo


if __name__ == "__main__":
    gradio_interface().launch()