Spaces:
Sleeping
Sleeping
File size: 4,867 Bytes
e279f9f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
import streamlit as st
from transformers import (
AutoTokenizer,
TFAutoModel,
AutoModelForSeq2SeqLM,
GenerationConfig,
)
from datasets import Dataset
from datasets import load_dataset
import pandas as pd
from transformers import pipeline
from peft import PeftModel
import torch
def get_query():
    """Return the pending search query and reset the stored inputs.

    A stored suggestion takes precedence over the text typed into the
    search box. Both session-state slots are cleared before returning.

    Returns:
        The suggestion if one is set (truthy), otherwise the value of
        ``st.session_state.user_query``.
    """
    state = st.session_state

    # Make sure both slots exist before they are read.
    if "suggestion" not in state:
        state.suggestion = None
    if "user_query" not in state:
        state.user_query = ""

    pending = state.suggestion
    query = pending if pending else state.user_query

    # Clear both inputs so the same query is not returned again.
    state.suggestion = None
    state.user_query = ""
    return query
def render_suggestions():
    """Render one button per example query, laid out in side-by-side columns.

    Clicking a button stores its label in ``st.session_state.suggestion``
    through the button's ``on_click`` callback.
    """

    def store_suggestion(query):
        # Button callback: remember which suggestion was clicked.
        st.session_state.suggestion = query

    suggestions = [
        "A girl who is cursed",
        "A movie that talks about the importance of education",
        "Story of a village head",
        "A movie released in 2020s about mistaken identity",
        "Estranged siblings meeting after long time",
    ]

    # One column per suggestion; pair each column with its label.
    for column, label in zip(st.columns(len(suggestions)), suggestions):
        with column:
            st.button(label, on_click=store_suggestion, args=[label])
def render_query():
    """Render the search text box, bound to the ``user_query`` session key."""
    hint = "Search, e.g. 'A gangster story with a twist'"
    # The label is passed but hidden via label_visibility="collapsed".
    st.text_input(
        "Search",
        key="user_query",
        placeholder=hint,
        label_visibility="collapsed",
    )
@st.cache_resource()
def load_model():
    """Load the sentence-embedding tokenizer and TensorFlow model.

    Cached with ``st.cache_resource`` rather than ``st.cache_data``:
    ``cache_data`` serializes the return value and hands out a copy on
    every call, whereas ``cache_resource`` keeps the single loaded
    instance and returns it as-is, which is what Streamlit recommends for
    ML models.

    Returns:
        A ``(tokenizer, model)`` tuple for the
        "sentence-transformers/multi-qa-mpnet-base-dot-v1" checkpoint.
    """
    model_ckpt = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
    tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
    # from_pt=True: build the TF model from the checkpoint's PyTorch weights.
    model = TFAutoModel.from_pretrained(model_ckpt, from_pt=True)
    return tokenizer, model
@st.cache_resource()
def load_peft_model():
    """Load the FLAN-T5 base model with the movie-summary PEFT adapter.

    Cached with ``st.cache_resource`` rather than ``st.cache_data``:
    ``cache_data`` serializes the return value and hands out a copy on
    every call, whereas ``cache_resource`` keeps the single loaded
    instance and returns it as-is, which is what Streamlit recommends for
    ML models.

    Returns:
        A ``(peft_model, peft_tokenizer)`` tuple. Note the order is
        model first, tokenizer second.
    """
    base_ckpt = "google/flan-t5-small"
    peft_model_base = AutoModelForSeq2SeqLM.from_pretrained(
        base_ckpt, torch_dtype=torch.bfloat16
    )
    peft_tokenizer = AutoTokenizer.from_pretrained(base_ckpt)
    # Attach the fine-tuned adapter on top of the base model; the adapter is
    # loaded for inference only (is_trainable=False).
    peft_model = PeftModel.from_pretrained(
        peft_model_base,
        "vishnupriyavr/flan-t5-movie-summary",
        torch_dtype=torch.bfloat16,
        is_trainable=False,
    )
    return peft_model, peft_tokenizer
@st.cache_resource()
def load_faiss_dataset():
    """Load the movie-plot dataset and attach a FAISS index to it.

    Cached with ``st.cache_resource`` so the indexed dataset is built once
    and the same object is returned on later calls; ``st.cache_data`` would
    serialize the return value and hand out a fresh copy per call.
    Callers should treat the returned dataset as shared and read-only.

    Returns:
        The dataset with a FAISS index on its "embeddings" column.
    """
    faiss_dataset = load_dataset(
        "vishnupriyavr/wiki-movie-plots-with-summaries-faiss-embeddings",
        split="train",
    )
    # Materialize the whole split as a DataFrame, then rebuild a Dataset
    # from it before indexing.
    faiss_dataset.set_format("pandas")
    df = faiss_dataset[:]
    plots_dataset = Dataset.from_pandas(df)
    plots_dataset.add_faiss_index(column="embeddings")
    return plots_dataset
def get_embeddings(text_list):
    """Embed a list of texts with the cached sentence-embedding model.

    Args:
        text_list: Texts to tokenize (padded and truncated) and embed.

    Returns:
        The result of ``cls_pooling`` applied to the model output.
    """
    tokenizer, model = load_model()
    batch = tokenizer(
        text_list,
        padding=True,
        truncation=True,
        return_tensors="tf",
    )
    # Hand the model a plain dict of the tokenizer's outputs.
    model_inputs = dict(batch.items())
    return cls_pooling(model(**model_inputs))
def cls_pooling(model_output):
    """Return index 0 along the second axis of ``last_hidden_state``.

    Args:
        model_output: Object exposing a ``last_hidden_state`` attribute that
            supports ``[:, 0]`` indexing.
    """
    hidden_states = model_output.last_hidden_state
    first_position = hidden_states[:, 0]
    return first_position
def search_movie(user_query, limit):
    """Look up the nearest movie plots for a free-text query.

    Args:
        user_query: Query text to embed and search with.
        limit: Number of nearest examples to request from the index.

    Returns:
        A DataFrame of the retrieved examples plus their scores, sorted by
        score in descending order, with the columns relabelled.
    """
    query_vector = get_embeddings([user_query]).numpy()
    indexed_plots = load_faiss_dataset()
    scores, samples = indexed_plots.get_nearest_examples(
        "embeddings", query_vector, k=limit
    )

    results = pd.DataFrame.from_dict(samples)
    results["scores"] = scores
    results = results.sort_values("scores", ascending=False)

    # Columns are relabelled purely by position.
    # NOTE(review): "scores" was appended as the last column above, yet the
    # last label here is "embeddings"; if the dataset's own final column is
    # the embeddings, these two labels end up swapped. Confirm against the
    # dataset schema.
    results.columns = [
        "release_year",
        "title",
        "cast",
        "wiki_page",
        "plot",
        "plot_length",
        "text",
        "scores",
        "embeddings",
    ]
    return results
def summarized_plot(sample_df, limit):
    """Generate a summary for each of the first ``limit`` plots.

    Args:
        sample_df: DataFrame with a "plot" column holding the plot texts.
        limit: Maximum number of rows (from the top) to summarize. If the
            frame has fewer rows, only the available rows are summarized.

    Returns:
        A list of decoded summary strings, one per summarized row, in row
        order.
    """
    peft_model, peft_tokenizer = load_peft_model()

    # The generation settings do not change between rows, so build them once.
    # NOTE(review): temperature is set without do_sample=True; it presumably
    # has no effect on non-sampling decoding. Confirm the intended strategy.
    generation_config = GenerationConfig(
        max_new_tokens=250, temperature=0.7, num_beams=1
    )

    # Slice rather than index by range(limit) so a frame with fewer rows
    # than `limit` does not raise IndexError. max() keeps a negative limit
    # meaning "no rows", as range() would.
    plots = sample_df["plot"].iloc[: max(limit, 0)]

    summaries = []
    for plot in plots:
        # The literal's continuation lines are flush-left on purpose so that
        # no source indentation leaks into the prompt text.
        prompt = f"""
Summarize the following movie plot.
{plot}
Summary: """
        input_ids = peft_tokenizer(prompt, return_tensors="pt").input_ids
        peft_model_outputs = peft_model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
        )
        summaries.append(
            peft_tokenizer.decode(
                peft_model_outputs[0], skip_special_tokens=True
            )
        )
    return summaries
def aggregate(items):
    """Merge items that share a URL into a single record per URL.

    Args:
        items: Iterable of dicts, each with "url", "title" and "text" keys.

    Returns:
        A list with one dict per distinct URL, in order of first appearance.
        Each dict carries the "url" and "title" of the first item seen for
        that URL, and a "text" made of all that URL's texts joined by a
        blank line.
    """
    # Bucket the items by URL, keeping first-seen order.
    by_url = {}
    for entry in items:
        url = entry["url"]
        if url not in by_url:
            by_url[url] = []
        by_url[url].append(entry)

    # Collapse each bucket into one record.
    return [
        {
            "url": members[0]["url"],
            "title": members[0]["title"],
            "text": "\n\n".join(member["text"] for member in members),
        }
        for members in by_url.values()
    ]
|