dejanseo committed · verified · Commit 1d3d42e · Parent(s): 1596ad8

Update src/streamlit_app.py

Files changed (1): src/streamlit_app.py (+16 −37)
src/streamlit_app.py CHANGED

@@ -1,3 +1,4 @@
+# streamlit_app.py
 import time
 import torch
 import streamlit as st
@@ -22,16 +23,14 @@ GENERATION_CONFIG: Dict[str, Any] = {
     "repetition_penalty": 1.10, "seed": 42, "sort_by": "logp/len",
 }
 
-# ------------------ MODEL LOADING (LOCAL 4-BIT) ------------------
+# ------------------ MODEL LOADING (CPU/GPU AUTO) ------------------
 @st.cache_resource
 def load_model() -> Tuple[MT5Tokenizer, MT5ForConditionalGeneration, torch.device]:
+    # Avoid CUDA initialization if no driver; select device explicitly.
+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
     tok = MT5Tokenizer.from_pretrained(MODEL_PATH, cache_dir=CACHE_DIR)
-    model = MT5ForConditionalGeneration.from_pretrained(
-        MODEL_PATH,
-        cache_dir=CACHE_DIR,
-        device_map={"": 0}
-    )
-    device = model.device
+    model = MT5ForConditionalGeneration.from_pretrained(MODEL_PATH, cache_dir=CACHE_DIR)
+    model.to(device)
     return tok, model, device
 
 # ------------------ GENERATION HELPERS ------------------
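Review note on the loader hunk: picking the device up front avoids touching CUDA on driverless hosts, where the old device_map={"": 0} path typically fails outright. A minimal standalone sketch of the same pattern, with a stand-in checkpoint name in place of the app's MODEL_PATH/CACHE_DIR:

    import torch
    from transformers import MT5Tokenizer, MT5ForConditionalGeneration

    def load_model(model_path: str = "google/mt5-small"):  # stand-in path, not the app's
        # torch.cuda.is_available() is safe to call on CPU-only machines.
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        tok = MT5Tokenizer.from_pretrained(model_path)
        model = MT5ForConditionalGeneration.from_pretrained(model_path).to(device)
        model.eval()  # the app only runs inference
        return tok, model, device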
@@ -44,8 +43,7 @@ def decode_sequences(tok: MT5Tokenizer, seqs: torch.Tensor) -> List[str]:
     return tok.batch_decode(seqs, skip_special_tokens=True)
 
 def avg_logprobs_from_generate(tok: MT5Tokenizer, gen) -> List[float]:
-    if not hasattr(gen, "scores"):
-        return [float("nan")] * gen.sequences.size(0)
+    if not hasattr(gen, "scores"): return [float("nan")] * gen.sequences.size(0)
     scores, seqs = gen.scores, gen.sequences
     nseq, eos_id, pad_id = seqs.size(0), tok.eos_token_id or 1, tok.pad_token_id
     sum_logp = torch.zeros(nseq, dtype=torch.float32, device=scores[0].device)
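For readers following the scoring code: gen.scores from generate(..., output_scores=True, return_dict_in_generate=True) is a tuple with one (batch, vocab) tensor per generated step, and gen.sequences holds the sampled ids. The body of avg_logprobs_from_generate below the lines shown is not part of this diff, so the following is only a sketch of how a length-normalized log-probability (the "logp/len" the config sorts by) can be derived from those two fields:

    import torch

    def avg_logprob_per_token(gen, pad_id: int) -> list:
        # Drop the decoder start token so sequences align with gen.scores.
        seqs = gen.sequences[:, 1:]
        # One log-softmax per step -> (batch, steps, vocab).
        logps = torch.stack([s.log_softmax(-1) for s in gen.scores], dim=1)
        # Pick the log-prob of each actually sampled token.
        tok_logps = logps.gather(2, seqs.unsqueeze(-1)).squeeze(-1)
        mask = seqs.ne(pad_id)  # ignore padding after early EOS
        return ((tok_logps * mask).sum(1) / mask.sum(1).clamp(min=1)).tolist()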
@@ -66,27 +64,15 @@ def avg_logprobs_from_generate(tok: MT5Tokenizer, gen) -> List[float]:
 # --- UPDATED sampling_generate function ---
 def sampling_generate(tok, model, device, inputs, top_n, temperature, top_p, no_repeat_ngram_size, repetition_penalty, bad_words_ids: List[List[int]] = None):
     """Now accepts a list of 'bad_words_ids' to forbid certain sequences."""
-    kwargs = dict(
-        max_length=MAX_TARGET_LENGTH,
-        do_sample=True,
-        temperature=temperature,
-        top_p=top_p,
-        num_return_sequences=top_n,
-        return_dict_in_generate=True,
-        output_scores=True
-    )
-    if no_repeat_ngram_size > 0:
-        kwargs["no_repeat_ngram_size"] = int(no_repeat_ngram_size)
-    if repetition_penalty != 1.0:
-        kwargs["repetition_penalty"] = float(repetition_penalty)
-    if bad_words_ids:
-        kwargs["bad_words_ids"] = bad_words_ids
+    kwargs = dict(max_length=MAX_TARGET_LENGTH, do_sample=True, temperature=temperature, top_p=top_p, num_return_sequences=top_n, return_dict_in_generate=True, output_scores=True)
+    if no_repeat_ngram_size > 0: kwargs["no_repeat_ngram_size"] = int(no_repeat_ngram_size)
+    if repetition_penalty != 1.0: kwargs["repetition_penalty"] = float(repetition_penalty)
+    if bad_words_ids: kwargs["bad_words_ids"] = bad_words_ids
 
     gen = model.generate(**inputs, **kwargs)
     return decode_sequences(tok, gen.sequences), avg_logprobs_from_generate(tok, gen)
 
-def normalize_text(s: str) -> str:
-    return " ".join(s.strip().lower().split())
+def normalize_text(s: str) -> str: return " ".join(s.strip().lower().split())
 
 # ------------------ STREAMLIT APP ------------------
 st.set_page_config(
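The collapsed kwargs block is behavior-preserving; the substantive feature here is bad_words_ids, which the batching loop below uses to ban every candidate already generated. A hedged usage sketch with invented phrases and settings:

    # Assumes tok, model, inputs exist as in load_model() and the app's encoding step.
    seen = ["blue widgets", "cheap blue widgets"]  # made-up earlier outputs
    bad_words_ids = tok(seen, add_special_tokens=False)["input_ids"]
    gen = model.generate(
        **inputs,
        max_length=64, do_sample=True, temperature=1.0, top_p=0.95,
        num_return_sequences=4,
        bad_words_ids=bad_words_ids,  # these exact token sequences are blocked
        return_dict_in_generate=True, output_scores=True,
    )

One thing worth verifying in the hunk below: tokenizing with padding=True appends pad ids to the shorter banned entries, and bad_words_ids matches token runs literally, so padded entries may not block the intended phrases; the unpadded form above is the conventional shape.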
@@ -126,23 +112,16 @@ if run_button:
     seen_texts_for_bad_words = set()
 
     num_batches = (TOTAL_DESIRED_CANDIDATES + GENERATION_BATCH_SIZE - 1) // GENERATION_BATCH_SIZE
-
     progress_bar = st.progress(0)
 
     for i in range(num_batches):
         current_seed = cfg["seed"] + i
         torch.manual_seed(current_seed)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(current_seed)
+        if torch.cuda.is_available(): torch.cuda.manual_seed_all(current_seed)
 
         bad_words_ids = None
         if seen_texts_for_bad_words:
-            bad_words_ids = tok(
-                list(seen_texts_for_bad_words),
-                add_special_tokens=False,
-                padding=True,
-                truncation=True
-            )["input_ids"]
+            bad_words_ids = tok(list(seen_texts_for_bad_words), add_special_tokens=False, padding=True, truncation=True)["input_ids"]
 
         batch_texts, batch_scores = sampling_generate(
             tok, model, device, inputs,
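On the reseeding one-liner: seeding both the CPU and (when present) CUDA generators with cfg["seed"] + i gives each batch its own reproducible sampling stream. The idea in isolation, with an invented batch count:

    import torch

    base_seed, num_batches = 42, 3  # illustrative values
    for i in range(num_batches):
        torch.manual_seed(base_seed + i)               # CPU RNG
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(base_seed + i)  # all GPU RNGs
        print(torch.rand(2))  # identical output on every rerun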
@@ -157,11 +136,11 @@
         all_texts.extend(batch_texts)
         all_scores.extend(batch_scores)
         for txt in batch_texts:
-            if txt:
-                seen_texts_for_bad_words.add(txt)
+            if txt: seen_texts_for_bad_words.add(txt)
 
         progress_bar.progress((i + 1) / num_batches)
 
+    # Deduplicate and finalize the list
     final_enriched = []
     final_seen_normalized = set()
     for txt, sc in zip(all_texts, all_scores):
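The diff ends heading into the new "# Deduplicate and finalize the list" pass, whose body is outside the hunk; here is a sketch of what a normalize_text-keyed dedup over (text, score) pairs typically looks like, with invented sample data:

    def normalize_text(s: str) -> str:
        return " ".join(s.strip().lower().split())

    all_texts  = ["Blue Widgets", "  blue   widgets ", "red widgets"]
    all_scores = [-0.41, -0.47, -0.63]  # made-up avg log-probs

    final_enriched, final_seen_normalized = [], set()
    for txt, sc in zip(all_texts, all_scores):
        key = normalize_text(txt)
        if txt and key not in final_seen_normalized:
            final_seen_normalized.add(key)
            final_enriched.append({"text": txt, "score": sc})
    # keeps "Blue Widgets" and "red widgets"; drops the near-duplicate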
 