PierreBrunelle committed on
Commit cbf16d4 · verified · 1 Parent(s): d8e4619

Create app.py

Files changed (1): app.py +330 -0
app.py ADDED
@@ -0,0 +1,330 @@
import gradio as gr
import pixeltable as pxt
from pixeltable.functions.mistralai import chat_completions
from datetime import datetime

import os
import getpass
import re

import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from textblob import TextBlob

# Ensure the necessary NLTK data is downloaded
nltk.download('punkt', quiet=True)
nltk.download('stopwords', quiet=True)
nltk.download('punkt_tab', quiet=True)

# Set up the Mistral API key
if 'MISTRAL_API_KEY' not in os.environ:
    os.environ['MISTRAL_API_KEY'] = getpass.getpass('Mistral AI API Key:')
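
# Note: the getpass prompt only fires when MISTRAL_API_KEY is absent from the
# environment; in a hosted deployment (e.g. a Hugging Face Spaces secret) the
# variable is typically pre-set and this block is skipped.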

# Define UDFs
@pxt.udf
def get_sentiment_score(text: str) -> float:
    return TextBlob(text).sentiment.polarity

@pxt.udf
def extract_keywords(text: str, num_keywords: int = 5) -> list:
    stop_words = set(stopwords.words('english'))
    words = word_tokenize(text.lower())
    keywords = [word for word in words if word.isalnum() and word not in stop_words]
    return sorted(set(keywords), key=keywords.count, reverse=True)[:num_keywords]

@pxt.udf
def calculate_readability(text: str) -> float:
    words = len(re.findall(r'\w+', text))
    sentences = len(re.findall(r'\w+[.!?]', text)) or 1
    average_words_per_sentence = words / sentences
    return 206.835 - 1.015 * average_words_per_sentence
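
# Note: calculate_readability is a simplified Flesch Reading Ease score. The full
# formula is 206.835 - 1.015 * (words / sentences) - 84.6 * (syllables / word);
# the syllable term is omitted here, so absolute values are inflated and the
# score is best read as a relative comparison between the two model responses.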

# Function to run inference and analysis
def run_inference_and_analysis(task, system_prompt, input_text, temperature, top_p,
                               max_tokens, min_tokens, stop, random_seed, safe_prompt):
    # Initialize Pixeltable
    pxt.drop_table('mistral_prompts', ignore_errors=True)
    t = pxt.create_table('mistral_prompts', {
        'task': pxt.StringType(),
        'system': pxt.StringType(),
        'input_text': pxt.StringType(),
        'timestamp': pxt.TimestampType(),
        'temperature': pxt.FloatType(),
        'top_p': pxt.FloatType(),
        'max_tokens': pxt.IntType(nullable=True),
        'min_tokens': pxt.IntType(nullable=True),
        'stop': pxt.StringType(nullable=True),
        'random_seed': pxt.IntType(nullable=True),
        'safe_prompt': pxt.BoolType(nullable=True)
    })
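
    # Note: dropping and recreating the table on every call resets stored history,
    # so the history/analysis tabs only ever reflect the current run. Creating the
    # table once, outside this function, would preserve rows across runs.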

    # Insert new row into Pixeltable
    t.insert([{
        'task': task,
        'system': system_prompt,
        'input_text': input_text,
        'timestamp': datetime.now(),
        'temperature': temperature,
        'top_p': top_p,
        'max_tokens': max_tokens,
        'min_tokens': min_tokens,
        'stop': stop,
        'random_seed': random_seed,
        'safe_prompt': safe_prompt
    }])

    # Define messages for chat completion
    msgs = [
        {'role': 'system', 'content': t.system},
        {'role': 'user', 'content': t.input_text}
    ]

    common_params = {
        'messages': msgs,
        'temperature': temperature,
        'top_p': top_p,
        'max_tokens': max_tokens if max_tokens is not None else 300,
        'min_tokens': min_tokens,
        'stop': stop.split(',') if stop else None,
        'random_seed': random_seed,
        'safe_prompt': safe_prompt
    }
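
    # Note: t.system and t.input_text are column references, not plain strings,
    # so the chat_completions() calls below define computed columns that
    # Pixeltable evaluates against each row rather than one-off API calls.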

    # Run inference with both models
    t['open_mistral_nemo'] = chat_completions(model='open-mistral-nemo', **common_params)
    t['mistral_medium'] = chat_completions(model='mistral-medium', **common_params)

    # Extract responses
    t['omn_response'] = t.open_mistral_nemo.choices[0].message.content.astype(pxt.StringType())
    t['ml_response'] = t.mistral_medium.choices[0].message.content.astype(pxt.StringType())

    # Run analysis
    t['large_sentiment_score'] = get_sentiment_score(t.ml_response)
    t['large_keywords'] = extract_keywords(t.ml_response)
    t['large_readability_score'] = calculate_readability(t.ml_response)
    t['open_sentiment_score'] = get_sentiment_score(t.omn_response)
    t['open_keywords'] = extract_keywords(t.omn_response)
    t['open_readability_score'] = calculate_readability(t.omn_response)
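
    # Note: the raw completions are JSON; indexing into choices[0].message.content
    # and casting with astype(pxt.StringType()) yields plain-string computed
    # columns that the sentiment/keyword/readability UDFs can consume.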

    # Retrieve results
    results = t.select(
        t.omn_response, t.ml_response,
        t.large_sentiment_score, t.open_sentiment_score,
        t.large_keywords, t.open_keywords,
        t.large_readability_score, t.open_readability_score
    ).tail(1)

    history = t.select(t.timestamp, t.task, t.system, t.input_text).order_by(t.timestamp, asc=False).collect().to_pandas()

    responses = t.select(t.timestamp, t.omn_response, t.ml_response).order_by(t.timestamp, asc=False).collect().to_pandas()

    analysis = t.select(
        t.timestamp,
        t.open_sentiment_score,
        t.large_sentiment_score,
        t.open_keywords,
        t.large_keywords,
        t.open_readability_score,
        t.large_readability_score
    ).order_by(t.timestamp, asc=False).collect().to_pandas()

    params = t.select(
        t.timestamp,
        t.temperature,
        t.top_p,
        t.max_tokens,
        t.min_tokens,
        t.stop,
        t.random_seed,
        t.safe_prompt
    ).order_by(t.timestamp, asc=False).collect().to_pandas()

    return (
        results['omn_response'][0],
        results['ml_response'][0],
        results['large_sentiment_score'][0],
        results['open_sentiment_score'][0],
        results['large_keywords'][0],
        results['open_keywords'][0],
        results['large_readability_score'][0],
        results['open_readability_score'][0],
        history,
        responses,
        analysis,
        params
    )
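
# Note: the order of the returned tuple must match the `outputs` list wired to
# submit_btn.click in gradio_interface below.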

# Gradio interface
def gradio_interface():
    with gr.Blocks(theme=gr.themes.Base(), title="Prompt Engineering and LLM Studio") as demo:
        gr.Markdown(
            """
            <img src="https://raw.githubusercontent.com/pixeltable/pixeltable/main/docs/source/data/pixeltable-logo-large.png" alt="Pixeltable" width="20%" />
            # Prompt Engineering and LLM Studio

            This application demonstrates how [Pixeltable](https://github.com/pixeltable/pixeltable) can be used for rapid and incremental prompt engineering
            and model comparison workflows. It showcases Pixeltable's ability to directly store, version, index,
            and transform data while providing an interactive interface to experiment with different prompts and models.

            Remember, effective prompt engineering often requires experimentation and iteration. Use this tool to systematically improve your prompts and understand how different inputs and parameters affect the LLM outputs.
            """
        )

        with gr.Row():
            with gr.Column():
                with gr.Accordion("What does it do?", open=False):
                    gr.Markdown(
                        """
                        1. **Data Organization**: Pixeltable uses tables and views to organize data, similar to traditional databases but with enhanced capabilities for AI workflows.
                        2. **Computed Columns**: These are dynamically generated columns based on expressions applied to existing columns.
                        3. **Data Storage**: All prompts, responses, and analysis results are stored in Pixeltable tables.
                        4. **Versioning**: Every operation is automatically versioned, allowing you to track changes over time.
                        5. **UDFs**: Sentiment scores, keywords, and readability scores are computed dynamically.
                        6. **Querying**: The history and analysis tabs leverage Pixeltable's querying capabilities to display results.
                        """
                    )

            with gr.Column():
                with gr.Accordion("How does it work?", open=False):
                    gr.Markdown(
                        """
                        1. **Define your task**: This helps you keep track of different experiments.
                        2. **Set up your prompt**: Enter a system prompt in the "System Prompt" field, then write your specific input or question in the "Input Text" field.
                        3. **Adjust parameters (optional)**: Adjust temperature, top_p, token limits, etc., to control the model's output.
                        4. **Run the analysis**: Click the "Run Inference and Analysis" button.
                        5. **Review the results**: Compare the responses from both models and examine the scores.
                        6. **Iterate and refine**: Based on the results, refine your prompt or adjust parameters.
                        """
                    )

        with gr.Row():
            with gr.Column():
                task = gr.Textbox(label="Task (Arbitrary Category)")
                system_prompt = gr.Textbox(label="System Prompt")
                input_text = gr.Textbox(label="Input Text")

                with gr.Accordion("Advanced Settings", open=False):
                    temperature = gr.Slider(minimum=0, maximum=1, value=0.7, step=0.1, label="Temperature")
                    top_p = gr.Slider(minimum=0, maximum=1, value=0.9, step=0.1, label="Top P")
                    max_tokens = gr.Number(label="Max Tokens", value=300)
                    min_tokens = gr.Number(label="Min Tokens", value=None)
                    stop = gr.Textbox(label="Stop Sequences (comma-separated)")
                    random_seed = gr.Number(label="Random Seed", value=None)
                    safe_prompt = gr.Checkbox(label="Safe Prompt", value=False)

                submit_btn = gr.Button("Run Inference and Analysis")

        with gr.Tabs():
            with gr.Tab("Prompt Input"):
                history = gr.Dataframe(
                    headers=["Timestamp", "Task", "System Prompt", "Input Text"],
                    wrap=True
                )

            with gr.Tab("Model Responses"):
                responses = gr.Dataframe(
                    headers=["Timestamp", "Open-Mistral-Nemo Response", "Mistral-Medium Response"],
                    wrap=True
                )

            with gr.Tab("Analysis Results"):
                analysis = gr.Dataframe(
                    headers=[
                        "Timestamp",
                        "Open-Mistral-Nemo Sentiment",
                        "Mistral-Medium Sentiment",
                        "Open-Mistral-Nemo Keywords",
                        "Mistral-Medium Keywords",
                        "Open-Mistral-Nemo Readability",
                        "Mistral-Medium Readability"
                    ],
                    wrap=True
                )

            with gr.Tab("Model Parameters"):
                params = gr.Dataframe(
                    headers=[
                        "Timestamp",
                        "Temperature",
                        "Top P",
                        "Max Tokens",
                        "Min Tokens",
                        "Stop Sequences",
                        "Random Seed",
                        "Safe Prompt"
                    ],
                    wrap=True
                )

        # Define the examples
        examples = [
            # Example 1: Sentiment Analysis
            ["Sentiment Analysis",
             "You are an AI trained to analyze the sentiment of text. Provide a detailed analysis of the emotional tone, highlighting key phrases that indicate sentiment.",
             "The new restaurant downtown exceeded all my expectations. The food was exquisite, the service impeccable, and the ambiance was perfect for a romantic evening. I can't wait to go back!",
             0.3, 0.95, 200, None, "", None, False],

            # Example 2: Code Explanation
            ["Code Explanation",
             "You are an expert programmer. Explain the given code snippet in simple terms, highlighting its purpose, key components, and potential improvements.",
             """
def quicksort(arr):
    if len(arr) <= 1:
        return arr
    pivot = arr[len(arr) // 2]
    left = [x for x in arr if x < pivot]
    middle = [x for x in arr if x == pivot]
    right = [x for x in arr if x > pivot]
    return quicksort(left) + middle + quicksort(right)
             """,
             0.7, 0.9, 400, 100, "In conclusion,", 42, True],

            # Example 3: Creative Writing
            ["Story Generation",
             "You are a creative writer. Generate a short, engaging story based on the given prompt. Include vivid descriptions and an unexpected twist.",
             "In a world where dreams are shared, a young girl discovers she can manipulate other people's dreams.",
             0.9, 0.8, 500, 200, "The end.", None, False]
        ]

        with gr.Column():
            omn_response = gr.Textbox(label="Open-Mistral-Nemo Response")
            ml_response = gr.Textbox(label="Mistral-Medium Response")

            with gr.Row():
                large_sentiment = gr.Number(label="Mistral-Medium Sentiment")
                open_sentiment = gr.Number(label="Open-Mistral-Nemo Sentiment")

            with gr.Row():
                large_keywords = gr.Textbox(label="Mistral-Medium Keywords")
                open_keywords = gr.Textbox(label="Open-Mistral-Nemo Keywords")

            with gr.Row():
                large_readability = gr.Number(label="Mistral-Medium Readability")
                open_readability = gr.Number(label="Open-Mistral-Nemo Readability")

        gr.Examples(
            examples=examples,
            inputs=[task, system_prompt, input_text, temperature, top_p, max_tokens, min_tokens, stop, random_seed, safe_prompt],
            outputs=[omn_response, ml_response, large_sentiment, open_sentiment, large_keywords, open_keywords, large_readability, open_readability],
            fn=run_inference_and_analysis,
            cache_examples=True,
        )
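
        # Note: cache_examples=True makes Gradio run run_inference_and_analysis
        # once per example at startup to pre-compute outputs, so a valid
        # MISTRAL_API_KEY must be available before launch. Also, the function
        # returns 12 values while this outputs list names 8; some Gradio
        # versions may reject that mismatch when caching.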

        gr.Markdown(
            """
            For more information, visit [Pixeltable's GitHub repository](https://github.com/pixeltable/pixeltable).
            """
        )

        submit_btn.click(
            run_inference_and_analysis,
            inputs=[task, system_prompt, input_text, temperature, top_p, max_tokens, min_tokens, stop, random_seed, safe_prompt],
            outputs=[omn_response, ml_response, large_sentiment, open_sentiment, large_keywords, open_keywords, large_readability, open_readability, history, responses, analysis, params]
        )

    return demo

# Launch the Gradio interface
if __name__ == "__main__":
    gradio_interface().launch()