Spaces:
Sleeping
Sleeping
Commit
·
97b987c
1
Parent(s):
58aefe3
Update streamlit_utils.py
Browse files
- streamlit_utils.py +16 -16
streamlit_utils.py
CHANGED
@@ -50,21 +50,18 @@ def render_query():
|
|
50 |
)
|
51 |
|
52 |
|
53 |
-
@st.cache_data()
|
54 |
def load_model():
|
55 |
model_ckpt = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
|
56 |
-
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
|
57 |
model = TFAutoModel.from_pretrained(model_ckpt, from_pt=True)
|
58 |
|
59 |
-
return
|
60 |
|
61 |
|
62 |
-
@st.cache_data()
|
63 |
def load_peft_model():
|
64 |
peft_model_base = AutoModelForSeq2SeqLM.from_pretrained(
|
65 |
"google/flan-t5-small", torch_dtype=torch.bfloat16
|
66 |
)
|
67 |
-
peft_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
|
68 |
|
69 |
peft_model = PeftModel.from_pretrained(
|
70 |
peft_model_base,
|
@@ -72,24 +69,22 @@ def load_peft_model():
|
|
72 |
torch_dtype=torch.bfloat16,
|
73 |
is_trainable=False,
|
74 |
)
|
75 |
-
return peft_model
|
76 |
|
77 |
|
78 |
-
@st.cache_data()
|
79 |
def load_faiss_dataset():
|
80 |
faiss_dataset = load_dataset(
|
81 |
"vishnupriyavr/wiki-movie-plots-with-summaries-faiss-embeddings",
|
82 |
split="train",
|
83 |
)
|
84 |
-
faiss_dataset
|
85 |
-
df = faiss_dataset[:]
|
86 |
-
plots_dataset = Dataset.from_pandas(df)
|
87 |
-
plots_dataset.add_faiss_index(column="embeddings")
|
88 |
-
return plots_dataset
|
89 |
|
90 |
|
91 |
def get_embeddings(text_list):
|
92 |
-
|
|
|
|
|
93 |
encoded_input = tokenizer(
|
94 |
text_list, padding=True, truncation=True, return_tensors="tf"
|
95 |
)
|
@@ -105,7 +100,11 @@ def cls_pooling(model_output):
|
|
105 |
def search_movie(user_query, limit):
|
106 |
question_embedding = get_embeddings([user_query]).numpy()
|
107 |
|
108 |
-
|
|
|
|
|
|
|
|
|
109 |
scores, samples = plots_dataset.get_nearest_examples(
|
110 |
"embeddings", question_embedding, k=limit
|
111 |
)
|
@@ -129,7 +128,8 @@ def search_movie(user_query, limit):
|
|
129 |
|
130 |
|
131 |
def summarized_plot(sample_df, limit):
|
132 |
-
peft_model
|
|
|
133 |
peft_model_text_output_list = []
|
134 |
|
135 |
for i in range(limit):
|
@@ -169,4 +169,4 @@ def aggregate(items):
|
|
169 |
result["title"] = group[0]["title"] # get titl from first item
|
170 |
result["text"] = "\n\n".join([item["text"] for item in group])
|
171 |
results.append(result)
|
172 |
-
return results
|
|
|
50 |
)
|
51 |
|
52 |
|
53 |
+
@st.cache_data(persist=True)
|
54 |
def load_model():
|
55 |
model_ckpt = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
|
|
|
56 |
model = TFAutoModel.from_pretrained(model_ckpt, from_pt=True)
|
57 |
|
58 |
+
return model
|
59 |
|
60 |
|
|
|
61 |
def load_peft_model():
|
62 |
peft_model_base = AutoModelForSeq2SeqLM.from_pretrained(
|
63 |
"google/flan-t5-small", torch_dtype=torch.bfloat16
|
64 |
)
|
|
|
65 |
|
66 |
peft_model = PeftModel.from_pretrained(
|
67 |
peft_model_base,
|
|
|
69 |
torch_dtype=torch.bfloat16,
|
70 |
is_trainable=False,
|
71 |
)
|
72 |
+
return peft_model
|
73 |
|
74 |
|
75 |
+
@st.cache_data(persist=True)
|
76 |
def load_faiss_dataset():
|
77 |
faiss_dataset = load_dataset(
|
78 |
"vishnupriyavr/wiki-movie-plots-with-summaries-faiss-embeddings",
|
79 |
split="train",
|
80 |
)
|
81 |
+
return faiss_dataset
|
|
|
|
|
|
|
|
|
82 |
|
83 |
|
84 |
def get_embeddings(text_list):
|
85 |
+
model = load_model()
|
86 |
+
model_ckpt = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
|
87 |
+
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
|
88 |
encoded_input = tokenizer(
|
89 |
text_list, padding=True, truncation=True, return_tensors="tf"
|
90 |
)
|
|
|
100 |
def search_movie(user_query, limit):
|
101 |
question_embedding = get_embeddings([user_query]).numpy()
|
102 |
|
103 |
+
faiss_dataset = load_faiss_dataset()
|
104 |
+
faiss_dataset.set_format("pandas")
|
105 |
+
df = faiss_dataset[:]
|
106 |
+
plots_dataset = Dataset.from_pandas(df)
|
107 |
+
plots_dataset.add_faiss_index(column="embeddings")
|
108 |
scores, samples = plots_dataset.get_nearest_examples(
|
109 |
"embeddings", question_embedding, k=limit
|
110 |
)
|
|
|
128 |
|
129 |
|
130 |
def summarized_plot(sample_df, limit):
|
131 |
+
peft_model = load_peft_model()
|
132 |
+
peft_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
|
133 |
peft_model_text_output_list = []
|
134 |
|
135 |
for i in range(limit):
|
|
|
169 |
result["title"] = group[0]["title"] # get titl from first item
|
170 |
result["text"] = "\n\n".join([item["text"] for item in group])
|
171 |
results.append(result)
|
172 |
+
return results
|