vishnupriyavr commited on
Commit
e279f9f
·
1 Parent(s): aebe8a5

Create streamlit_utils.py

Browse files
Files changed (1) hide show
  1. streamlit_utils.py +172 -0
streamlit_utils.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import (
3
+ AutoTokenizer,
4
+ TFAutoModel,
5
+ AutoModelForSeq2SeqLM,
6
+ GenerationConfig,
7
+ )
8
+ from datasets import Dataset
9
+ from datasets import load_dataset
10
+ import pandas as pd
11
+ from transformers import pipeline
12
+ from peft import PeftModel
13
+ import torch
14
+
15
+
16
def get_query():
    """Return the query to run: a clicked suggestion takes priority over typed text.

    Both session-state keys are consumed (reset) so each query is handled
    exactly once per rerun.
    """
    st.session_state.setdefault("suggestion", None)
    st.session_state.setdefault("user_query", "")
    query = st.session_state.suggestion or st.session_state.user_query
    # Consume both sources so the next rerun starts clean.
    st.session_state.suggestion = None
    st.session_state.user_query = ""
    return query
25
+
26
+
27
def render_suggestions():
    """Render one clickable example-query button per column."""

    def _pick(query):
        # A clicked suggestion is stashed for get_query() to consume.
        st.session_state.suggestion = query

    suggestions = [
        "A girl who is cursed",
        "A movie that talks about the importance of education",
        "Story of a village head",
        "A movie released in 2020s about mistaken identity",
        "Estranged siblings meeting after long time",
    ]
    for text, column in zip(suggestions, st.columns(len(suggestions))):
        with column:
            st.button(text, on_click=_pick, args=[text])
42
+
43
+
44
def render_query():
    """Render the free-text search box bound to session key ``user_query``."""
    st.text_input(
        label="Search",
        placeholder="Search, e.g. 'A gangster story with a twist'",
        key="user_query",
        label_visibility="collapsed",
    )
51
+
52
+
53
@st.cache_resource
def load_model():
    """Load and cache the sentence-embedding tokenizer and TF model.

    Uses ``st.cache_resource`` rather than ``st.cache_data``: model and
    tokenizer objects are global, unserializable resources, and
    ``cache_data`` would attempt to pickle/copy the return value on every
    access.

    Returns:
        tuple: ``(tokenizer, model)`` for
        ``sentence-transformers/multi-qa-mpnet-base-dot-v1``.
    """
    model_ckpt = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
    tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
    # from_pt=True: the checkpoint ships PyTorch weights; convert to TF.
    model = TFAutoModel.from_pretrained(model_ckpt, from_pt=True)

    return tokenizer, model
60
+
61
+
62
@st.cache_resource
def load_peft_model():
    """Load and cache the frozen PEFT summarization model and its tokenizer.

    Uses ``st.cache_resource`` rather than ``st.cache_data``: a loaded
    torch/PEFT model is an unserializable global resource that must be
    cached by reference, not pickled.

    Returns:
        tuple: ``(peft_model, peft_tokenizer)`` — flan-t5-small with the
        ``vishnupriyavr/flan-t5-movie-summary`` adapter applied, inference
        only (``is_trainable=False``).
    """
    peft_model_base = AutoModelForSeq2SeqLM.from_pretrained(
        "google/flan-t5-small", torch_dtype=torch.bfloat16
    )
    peft_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")

    peft_model = PeftModel.from_pretrained(
        peft_model_base,
        "vishnupriyavr/flan-t5-movie-summary",
        torch_dtype=torch.bfloat16,
        is_trainable=False,
    )
    return peft_model, peft_tokenizer
76
+
77
+
78
@st.cache_resource
def load_faiss_dataset():
    """Load the movie-plot embeddings dataset and attach a FAISS index.

    Uses ``st.cache_resource`` rather than ``st.cache_data``: a Dataset
    carrying a FAISS index is not reliably picklable, and the index should
    be built once and shared by reference across reruns.

    Returns:
        datasets.Dataset: plots dataset with a FAISS index on the
        ``embeddings`` column, ready for ``get_nearest_examples``.
    """
    faiss_dataset = load_dataset(
        "vishnupriyavr/wiki-movie-plots-with-summaries-faiss-embeddings",
        split="train",
    )
    # Round-trip through pandas to materialize the split as an in-memory
    # Dataset before indexing.
    faiss_dataset.set_format("pandas")
    df = faiss_dataset[:]
    plots_dataset = Dataset.from_pandas(df)
    plots_dataset.add_faiss_index(column="embeddings")
    return plots_dataset
89
+
90
+
91
def get_embeddings(text_list):
    """Embed a list of strings with the cached sentence-transformer.

    Args:
        text_list: list of strings to embed.

    Returns:
        TF tensor of shape ``(len(text_list), hidden_dim)`` — the
        first-token ([CLS]) embedding of each input.
    """
    tokenizer, model = load_model()
    tokens = tokenizer(
        text_list, padding=True, truncation=True, return_tensors="tf"
    )
    outputs = model(**tokens)
    return cls_pooling(outputs)
99
+
100
+
101
def cls_pooling(model_output):
    """Return the first-token ([CLS]) hidden state for every sequence.

    Args:
        model_output: model output exposing ``last_hidden_state`` of shape
            ``(batch, seq_len, hidden_dim)``.

    Returns:
        Tensor of shape ``(batch, hidden_dim)``.
    """
    hidden_states = model_output.last_hidden_state
    return hidden_states[:, 0]
103
+
104
+
105
def search_movie(user_query, limit):
    """Retrieve the ``limit`` nearest movie plots for a free-text query.

    Args:
        user_query: natural-language search string.
        limit: number of nearest neighbours to fetch (FAISS ``k``).

    Returns:
        pandas.DataFrame of hits sorted by score, descending.
    """
    query_vector = get_embeddings([user_query]).numpy()

    plots = load_faiss_dataset()
    scores, hits = plots.get_nearest_examples(
        "embeddings", query_vector, k=limit
    )

    results = pd.DataFrame.from_dict(hits)
    results["scores"] = scores
    results.sort_values("scores", ascending=False, inplace=True)

    # NOTE(review): this positional rename assumes the dataset's columns
    # arrive in exactly this order (with "scores" second-to-last) — verify
    # against the dataset schema.
    results.columns = [
        "release_year",
        "title",
        "cast",
        "wiki_page",
        "plot",
        "plot_length",
        "text",
        "scores",
        "embeddings",
    ]
    return results
129
+
130
+
131
def summarized_plot(sample_df, limit):
    """Generate a short summary for each of the top plots in ``sample_df``.

    Args:
        sample_df: DataFrame of search hits containing a ``plot`` column.
        limit: maximum number of rows to summarize.

    Returns:
        list[str]: one generated summary per summarized row — at most
        ``limit``, fewer when the frame has fewer rows.
    """
    peft_model, peft_tokenizer = load_peft_model()
    summaries = []

    # Bound by the frame length: the search may return fewer than `limit`
    # hits, and iloc past the end would raise IndexError.
    for i in range(min(limit, len(sample_df))):
        prompt = f"""
Summarize the following movie plot.

{sample_df.iloc[i]["plot"]}

Summary: """

        input_ids = peft_tokenizer(prompt, return_tensors="pt").input_ids

        # NOTE(review): temperature has no effect here — sampling is not
        # enabled (num_beams=1, no do_sample) so decoding is greedy.
        # Confirm whether sampling was intended before changing.
        outputs = peft_model.generate(
            input_ids=input_ids,
            generation_config=GenerationConfig(
                max_new_tokens=250, temperature=0.7, num_beams=1
            ),
        )
        summaries.append(
            peft_tokenizer.decode(outputs[0], skip_special_tokens=True)
        )

    return summaries
157
+
158
+
159
def aggregate(items):
    """Merge items sharing a URL into one result per distinct URL.

    Args:
        items: iterable of dicts with ``url``, ``title``, and ``text`` keys.

    Returns:
        list[dict]: one dict per distinct url, in first-seen order, taking
        ``url``/``title`` from the first item and joining all ``text``
        values with blank lines.
    """
    by_url = {}
    for item in items:
        by_url.setdefault(item["url"], []).append(item)

    return [
        {
            "url": group[0]["url"],
            "title": group[0]["title"],
            "text": "\n\n".join(item["text"] for item in group),
        }
        for group in by_url.values()
    ]