vishnupriyavr committed
Commit 97b987c · Parent: 58aefe3

Update streamlit_utils.py

Files changed (1):
  1. streamlit_utils.py +16 -16
streamlit_utils.py CHANGED
@@ -50,21 +50,18 @@ def render_query():
     )
 
 
-@st.cache_data()
+@st.cache_data(persist=True)
 def load_model():
     model_ckpt = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
-    tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
     model = TFAutoModel.from_pretrained(model_ckpt, from_pt=True)
 
-    return tokenizer, model
+    return model
 
 
-@st.cache_data()
 def load_peft_model():
     peft_model_base = AutoModelForSeq2SeqLM.from_pretrained(
         "google/flan-t5-small", torch_dtype=torch.bfloat16
     )
-    peft_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
 
     peft_model = PeftModel.from_pretrained(
         peft_model_base,
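This hunk moves `load_model` to `st.cache_data(persist=True)` and removes the decorator from `load_peft_model` altogether. Worth noting for reviewers: `st.cache_data` caches by pickling return values, which is a poor fit for model objects; Streamlit provides `st.cache_resource` for exactly this case. A minimal sketch of that alternative on the same checkpoint, not what this commit does:

```python
import streamlit as st
from transformers import TFAutoModel

# Sketch only: st.cache_resource keeps one shared, unpickled instance per
# process, so the TF model loads once instead of on every script rerun.
@st.cache_resource
def load_model():
    model_ckpt = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
    return TFAutoModel.from_pretrained(model_ckpt, from_pt=True)
```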
@@ -72,24 +69,22 @@ def load_peft_model():
         torch_dtype=torch.bfloat16,
         is_trainable=False,
     )
-    return peft_model, peft_tokenizer
+    return peft_model
 
 
-@st.cache_data()
+@st.cache_data(persist=True)
 def load_faiss_dataset():
     faiss_dataset = load_dataset(
         "vishnupriyavr/wiki-movie-plots-with-summaries-faiss-embeddings",
         split="train",
     )
-    faiss_dataset.set_format("pandas")
-    df = faiss_dataset[:]
-    plots_dataset = Dataset.from_pandas(df)
-    plots_dataset.add_faiss_index(column="embeddings")
-    return plots_dataset
+    return faiss_dataset
 
 
 def get_embeddings(text_list):
-    tokenizer, model = load_model()
+    model = load_model()
+    model_ckpt = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
+    tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
     encoded_input = tokenizer(
         text_list, padding=True, truncation=True, return_tensors="tf"
     )
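A side effect of this hunk: `get_embeddings` now calls `AutoTokenizer.from_pretrained` on every invocation, so each query pays a tokenizer-load cost that the old `load_model` amortized. The `cls_pooling` named in the next hunk header is not shown in this diff; for this checkpoint it is conventionally first-token pooling, so a plausible sketch of the missing body is:

```python
# Assumed implementation: multi-qa-mpnet-base-dot-v1 is normally used with
# CLS pooling, i.e. the hidden state of the first token as the embedding.
def cls_pooling(model_output):
    return model_output.last_hidden_state[:, 0]
```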
@@ -105,7 +100,11 @@ def cls_pooling(model_output):
 def search_movie(user_query, limit):
     question_embedding = get_embeddings([user_query]).numpy()
 
-    plots_dataset = load_faiss_dataset()
+    faiss_dataset = load_faiss_dataset()
+    faiss_dataset.set_format("pandas")
+    df = faiss_dataset[:]
+    plots_dataset = Dataset.from_pandas(df)
+    plots_dataset.add_faiss_index(column="embeddings")
     scores, samples = plots_dataset.get_nearest_examples(
         "embeddings", question_embedding, k=limit
     )
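Moving the pandas round-trip and `add_faiss_index` into `search_movie` rebuilds the FAISS index on every query, which is the expensive step on this path. If the `embeddings` column is already a fixed-length float array, `datasets` can index the split directly; a hedged sketch of that one-time setup (`question_embedding` and `limit` as in `search_movie`):

```python
from datasets import load_dataset

# Hypothetical one-time setup (module level or behind a cache), so the
# FAISS index is built once rather than per search:
plots_dataset = load_dataset(
    "vishnupriyavr/wiki-movie-plots-with-summaries-faiss-embeddings",
    split="train",
)
plots_dataset.add_faiss_index(column="embeddings")

# Per-query lookup is unchanged:
scores, samples = plots_dataset.get_nearest_examples(
    "embeddings", question_embedding, k=limit
)
```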
@@ -129,7 +128,8 @@ def search_movie(user_query, limit):
 
 
 def summarized_plot(sample_df, limit):
-    peft_model, peft_tokenizer = load_peft_model()
+    peft_model = load_peft_model()
+    peft_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
     peft_model_text_output_list = []
 
     for i in range(limit):
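With `load_peft_model` returning only the model, `summarized_plot` now fetches its tokenizer separately; both must stay pinned to the same "google/flan-t5-small" base. The generation loop is outside this diff; a minimal sketch of how such a model/tokenizer pair is typically driven, with the prompt wording and decoding settings as illustrative assumptions:

```python
import torch

# Illustrative only: the prompt format and generation parameters are
# assumptions, not taken from this file.
def summarize_one(plot_text, peft_model, peft_tokenizer):
    prompt = f"Summarize the following movie plot:\n\n{plot_text}\n\nSummary:"
    input_ids = peft_tokenizer(prompt, return_tensors="pt").input_ids
    with torch.no_grad():
        output_ids = peft_model.generate(input_ids=input_ids, max_new_tokens=128)
    return peft_tokenizer.decode(output_ids[0], skip_special_tokens=True)
```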
@@ -169,4 +169,4 @@ def aggregate(items):
         result["title"] = group[0]["title"]  # get title from first item
         result["text"] = "\n\n".join([item["text"] for item in group])
         results.append(result)
-    return results
+    return results