Update src/streamlit_app.py

src/streamlit_app.py (+16 -37)
--- a/src/streamlit_app.py
+++ b/src/streamlit_app.py
@@ -1,3 +1,4 @@
+# streamlit_app.py
 import time
 import torch
 import streamlit as st
@@ -22,16 +23,14 @@ GENERATION_CONFIG: Dict[str, Any] = {
     "repetition_penalty": 1.10, "seed": 42, "sort_by": "logp/len",
 }
 
-# ------------------ MODEL LOADING (
+# ------------------ MODEL LOADING (CPU/GPU AUTO) ------------------
 @st.cache_resource
 def load_model() -> Tuple[MT5Tokenizer, MT5ForConditionalGeneration, torch.device]:
+    # Avoid CUDA initialization if no driver; select device explicitly.
+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
     tok = MT5Tokenizer.from_pretrained(MODEL_PATH, cache_dir=CACHE_DIR)
-    model = MT5ForConditionalGeneration.from_pretrained(
-        MODEL_PATH,
-        cache_dir=CACHE_DIR,
-        device_map={"": 0}
-    )
-    device = model.device
+    model = MT5ForConditionalGeneration.from_pretrained(MODEL_PATH, cache_dir=CACHE_DIR)
+    model.to(device)
     return tok, model, device
 
 # ------------------ GENERATION HELPERS ------------------
@@ -44,8 +43,7 @@ def decode_sequences(tok: MT5Tokenizer, seqs: torch.Tensor) -> List[str]:
     return tok.batch_decode(seqs, skip_special_tokens=True)
 
 def avg_logprobs_from_generate(tok: MT5Tokenizer, gen) -> List[float]:
-    if not hasattr(gen, "scores"):
-        return [float("nan")] * gen.sequences.size(0)
+    if not hasattr(gen, "scores"): return [float("nan")] * gen.sequences.size(0)
     scores, seqs = gen.scores, gen.sequences
     nseq, eos_id, pad_id = seqs.size(0), tok.eos_token_id or 1, tok.pad_token_id
     sum_logp = torch.zeros(nseq, dtype=torch.float32, device=scores[0].device)
@@ -66,27 +64,15 @@ def avg_logprobs_from_generate(tok: MT5Tokenizer, gen) -> List[float]:
 # --- UPDATED sampling_generate function ---
 def sampling_generate(tok, model, device, inputs, top_n, temperature, top_p, no_repeat_ngram_size, repetition_penalty, bad_words_ids: List[List[int]] = None):
     """Now accepts a list of 'bad_words_ids' to forbid certain sequences."""
-    kwargs = dict(
-        max_length=MAX_TARGET_LENGTH,
-        do_sample=True,
-        temperature=temperature,
-        top_p=top_p,
-        num_return_sequences=top_n,
-        return_dict_in_generate=True,
-        output_scores=True
-    )
-    if no_repeat_ngram_size > 0:
-        kwargs["no_repeat_ngram_size"] = int(no_repeat_ngram_size)
-    if repetition_penalty != 1.0:
-        kwargs["repetition_penalty"] = float(repetition_penalty)
-    if bad_words_ids:
-        kwargs["bad_words_ids"] = bad_words_ids
+    kwargs = dict(max_length=MAX_TARGET_LENGTH, do_sample=True, temperature=temperature, top_p=top_p, num_return_sequences=top_n, return_dict_in_generate=True, output_scores=True)
+    if no_repeat_ngram_size > 0: kwargs["no_repeat_ngram_size"] = int(no_repeat_ngram_size)
+    if repetition_penalty != 1.0: kwargs["repetition_penalty"] = float(repetition_penalty)
+    if bad_words_ids: kwargs["bad_words_ids"] = bad_words_ids
 
     gen = model.generate(**inputs, **kwargs)
     return decode_sequences(tok, gen.sequences), avg_logprobs_from_generate(tok, gen)
 
-def normalize_text(s: str) -> str:
-    return " ".join(s.strip().lower().split())
+def normalize_text(s: str) -> str: return " ".join(s.strip().lower().split())
 
 # ------------------ STREAMLIT APP ------------------
 st.set_page_config(
@@ -126,23 +112,16 @@ if run_button:
     seen_texts_for_bad_words = set()
 
     num_batches = (TOTAL_DESIRED_CANDIDATES + GENERATION_BATCH_SIZE - 1) // GENERATION_BATCH_SIZE
-
     progress_bar = st.progress(0)
 
     for i in range(num_batches):
        current_seed = cfg["seed"] + i
        torch.manual_seed(current_seed)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(current_seed)
+        if torch.cuda.is_available(): torch.cuda.manual_seed_all(current_seed)
 
         bad_words_ids = None
         if seen_texts_for_bad_words:
-            bad_words_ids = tok(
-                list(seen_texts_for_bad_words),
-                add_special_tokens=False,
-                padding=True,
-                truncation=True
-            )["input_ids"]
+            bad_words_ids = tok(list(seen_texts_for_bad_words), add_special_tokens=False, padding=True, truncation=True)["input_ids"]
 
         batch_texts, batch_scores = sampling_generate(
             tok, model, device, inputs,
@@ -157,11 +136,11 @@ if run_button:
         all_texts.extend(batch_texts)
         all_scores.extend(batch_scores)
         for txt in batch_texts:
-            if txt:
-                seen_texts_for_bad_words.add(txt)
+            if txt: seen_texts_for_bad_words.add(txt)
 
         progress_bar.progress((i + 1) / num_batches)
 
+    # Deduplicate and finalize the list
     final_enriched = []
     final_seen_normalized = set()
     for txt, sc in zip(all_texts, all_scores):
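The removed device_map={"": 0} pinned every weight to GPU 0, which fails outright on CPU-only Space hardware; the replacement probes torch.cuda.is_available() and runs in either environment. One consequence is that callers must move their input tensors onto the returned device before generating. A minimal usage sketch of that pattern (the prompt string and generation length here are hypothetical, not values from the app):

import torch

# load_model() is the cached loader from the diff above.
tok, model, device = load_model()

# Tokenize a source string and move the tensors onto the model's device;
# leaving them on CPU while the model sits on GPU raises a RuntimeError.
inputs = tok("example source sentence", return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

with torch.no_grad():
    out = model.generate(**inputs, max_length=64)  # arbitrary length for the demo
print(tok.batch_decode(out, skip_special_tokens=True))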
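For reference, bad_words_ids is a standard generate() argument: a list of token-id lists, each of which the decoder is forbidden to emit as an exact sequence. A minimal sketch of the banning mechanism in isolation, with a hypothetical banned string (tokenizing without padding keeps each banned sequence at its natural length):

import torch

tok, model, device = load_model()

# Hypothetical: ban one previously generated string from reappearing.
banned = ["example output to suppress"]
bad_words_ids = tok(banned, add_special_tokens=False)["input_ids"]

inputs = tok("example source sentence", return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

gen = model.generate(
    **inputs,
    do_sample=True,
    num_return_sequences=4,
    bad_words_ids=bad_words_ids,  # exact token-id sequences to forbid
    return_dict_in_generate=True,
    output_scores=True,
)
print(tok.batch_decode(gen.sequences, skip_special_tokens=True))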
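The avg_logprobs_from_generate helper is only partially visible in the hunks above. The usual way to recover per-sequence average log-probabilities from output_scores=True is to stack the per-step logits, take log_softmax, and gather the ids that were actually generated. A self-contained sketch of that computation, not the app's exact implementation (assumptions: an encoder-decoder model, so sequences[:, 0] is the decoder start token, and pad tokens mark unused trailing steps):

import torch
import torch.nn.functional as F

def avg_logprob_sketch(gen, pad_id: int) -> torch.Tensor:
    # gen.scores is a tuple with one (num_seqs, vocab_size) logits tensor
    # per generated step; step t predicts gen.sequences[:, t + 1].
    logp = F.log_softmax(torch.stack(gen.scores, dim=1), dim=-1)
    tokens = gen.sequences[:, 1:]                      # drop decoder start token
    tok_logp = logp.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)
    mask = (tokens != pad_id).float()                  # ignore padded trailing steps
    return (tok_logp * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)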