mgfrantz commited on
Commit
7441e5b
·
verified ·
1 Parent(s): b1df52f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +171 -0
app.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from datasets import load_dataset, Dataset
3
+ from llama_index.core import PromptTemplate
4
+ from llama_index.core.prompts import ChatMessage
5
+ from llama_index.llms.openai import OpenAI
6
+ from pydantic import BaseModel, Field
7
+ import asyncio
8
+ import numpy as np
9
+ import pandas as pd
10
+ from chromadb import Client
11
+ from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
12
+ import structlog
13
+
14
+ logger = structlog.get_logger()
15
+
16
+ logger.info('Loading embedding model')
17
+ embed_fn = SentenceTransformerEmbeddingFunction('BAAI/bge-small-en-v1.5')
18
+
19
+
20
+ def load_train_data_and_vectorstore():
21
+ logger.info("Loading dataset")
22
+ ds = load_dataset('SetFit/amazon_reviews_multi_en')
23
+ train_samples_per_class = 50
24
+ eval_test_samples_per_class = 10
25
+ train = Dataset.from_pandas(ds['train'].to_pandas().groupby('label').sample(train_samples_per_class, random_state=1234).reset_index(drop=True))
26
+ reviews = Client().create_collection(
27
+ name='reviews',
28
+ embedding_function=embed_fn,
29
+ get_or_create=True
30
+ )
31
+ logger.info("Adding documents to vector store")
32
+ reviews.add(documents=train['text'], metadatas=[{'rating': x} for x in train['label']], ids=train['id'])
33
+ return train, reviews
34
+
35
+ train, reviews = load_train_data_and_vectorstore()
36
+
37
+ class Rating(BaseModel):
38
+ rating: int = Field(..., description="Rating of the review", enum=[0, 1, 2, 3, 4])
39
+
40
+ llm = OpenAI(model="gpt-4o-mini")
41
+ structured_llm = llm.as_structured_llm(Rating)
42
+
43
+ prompt_tmpl_str = """\
44
+ The review text is below.
45
+ ---------------------
46
+ {review}
47
+ ---------------------
48
+ Given the review text and not prior knowledge, \
49
+ please attempt to predict the score of the review.
50
+
51
+ Query: What is the rating of this review?
52
+ Answer: \
53
+ """
54
+
55
+ prompt_tmpl = PromptTemplate(
56
+ prompt_tmpl_str,
57
+ )
58
+
59
+ async def zero_shot_predict(text):
60
+ messages = [
61
+ ChatMessage.from_str(prompt_tmpl.format(review=text))
62
+ ]
63
+ response = await structured_llm.achat(messages)
64
+ return response.raw.rating
65
+
66
+ few_shot_prompt_tmpl_str = """\
67
+ The review text is below.
68
+ ---------------------
69
+ {review}
70
+ ---------------------
71
+ Given the review text and not prior knowledge, \
72
+ please attempt to predict the review score of the context. \
73
+ Here are several examples of reviews and their ratings:
74
+
75
+ {random_few_shot_examples}
76
+
77
+ Query: What is the rating of this review?
78
+ Answer: \
79
+ """
80
+
81
+ few_shot_prompt_tmpl = PromptTemplate(
82
+ few_shot_prompt_tmpl_str,
83
+ function_mappings={"random_few_shot_examples": random_few_shot_examples_fn},
84
+ )
85
+
86
+ rng = np.random.Generator(np.random.PCG64(1234))
87
+ def random_few_shot_examples_fn(**kwargs):
88
+ if n_samples:=kwargs.get('n_samples'):
89
+ random_examples = train.shuffle(generator=rng)[:n_samples]
90
+ else:
91
+ random_examples = train.shuffle(generator=rng)[:5]
92
+
93
+ result_strs = []
94
+ for text, rating in zip(random_examples['text'], random_examples['label']):
95
+ result_strs.append(f"Text: {text}\nRating: {rating}")
96
+ return "\n\n".join(result_strs)
97
+
98
+ async def random_few_shot_predict(text, n_examples=5):
99
+ tasks = []
100
+ for _ in range(3):
101
+ messages = [
102
+ ChatMessage.from_str(few_shot_prompt_tmpl.format(review=text, n_samples=n_examples))
103
+ ]
104
+ tasks.append(structured_llm.achat(messages, temperature=0.9))
105
+ results = await asyncio.gather(*tasks)
106
+ ratings = [r.raw.rating for r in results]
107
+ # print(ratings)
108
+ return pd.Series(ratings).mode()[0]
109
+
110
+ def dynamic_few_shot_examples_fn(**kwargs):
111
+ n_examples = kwargs.get('n_examples', 5)
112
+ retrievals = reviews.query(
113
+ query_texts=[kwargs['review']],
114
+ n_results=n_examples
115
+ )
116
+ result_strs = []
117
+ documents = retrievals['documents'][0]
118
+ metadatas = retrievals['metadatas'][0]
119
+ for document, metadata in zip(documents, metadatas):
120
+ result_strs.append(f"Text: {document}\nRating: {metadata.get('rating')}")
121
+ return "\n\n".join(result_strs)
122
+
123
+ dynamic_few_shot_prompt_tmpl_str = """\
124
+ The review text is below.
125
+ ---------------------
126
+ {review}
127
+ ---------------------
128
+ Given the review text and not prior knowledge, \
129
+ please attempt to predict the review score of the context. \
130
+ Here are several examples of reviews and their ratings:
131
+
132
+ {dynamic_few_shot_examples}
133
+
134
+ Query: What is the rating of this review?
135
+ Answer: \
136
+ """
137
+
138
+ dynamic_few_shot_prompt_tmpl = PromptTemplate(
139
+ dynamic_few_shot_prompt_tmpl_str,
140
+ function_mappings={"dynamic_few_shot_examples": dynamic_few_shot_examples_fn},
141
+ )
142
+
143
+ async def dynamic_few_shot_predict(text, n_examples=5):
144
+ messages = [
145
+ ChatMessage.from_str(dynamic_few_shot_prompt_tmpl.format(review=text, n_examples=n_examples))
146
+ ]
147
+ response = await dynamic_few_shot_structured_llm.achat(messages)
148
+ return response.raw.rating
149
+
150
+ def classify(review, num_examples, api_key):
151
+ llm = OpenAI(model="gpt-4o-mini", api_key=api_key).as_structured_llm(Rating)
152
+ zero_shot = asyncio.run(zero_shot_predict(review))
153
+ random_few_shot = asyncio.run(random_few_shot_predict(review, num_examples))
154
+ dynamic_few_shot = asyncio.run(dynamic_few_shot_predict(review, num_examples))
155
+ return zero_shot, random_few_shot, dynamic_few_shot
156
+
157
+ with gr.Blocks() as demo:
158
+ with gr.Row():
159
+ with gr.Column():
160
+ api_key = gr.Textbox(label='Openai API Key')
161
+ n_examples = gr.Slider(minimum=1, maximum=10, value=5, step=1, label='Number of examples to retrieve', interactive=True)
162
+ review = gr.Textbox(label='Review', interactive=True)
163
+ submit = gr.Button(value='Submit')
164
+ with gr.Column():
165
+ zero_shot_label = gr.Textbox(label='Zero shot', interactive=False)
166
+ random_few_shot_label = gr.Textbox(label='Random few shot', interactive=False)
167
+ dynamic_few_shot_label = gr.Textbox(label='Dynamic few shot', interactive=False)
168
+
169
+ submit.click(classify, [review, n_examples], [zero_shot_label, random_few_shot_label, dynamic_few_shot_label])
170
+
171
+ demo.queue().launch()