import os
# Keep Dynamo error suppression
import torch._dynamo
torch._dynamo.config.suppress_errors = True
os.environ["MKL_THREADING_LAYER"] = "GNU"
import spaces
from peft import PeftModel
import traceback
import torch
from transformers import (
    pipeline,
    AutoTokenizer,
    AutoModelForCausalLM,
    StoppingCriteria,
    BitNetForCausalLM,
)
from .prompts import format_rag_prompt
# Remove interrupt import
# from .shared import generation_interrupt
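# Display name -> Hugging Face repo id for every candidate model used by the app.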
models = {
    "Qwen2.5-1.5b-Instruct": "qwen/qwen2.5-1.5b-instruct",
    "Qwen2.5-3b-Instruct": "qwen/qwen2.5-3b-instruct",
    "Llama-3.2-1b-Instruct": "meta-llama/llama-3.2-1b-instruct",
    "Llama-3.2-3b-Instruct": "meta-llama/llama-3.2-3b-instruct",
    "Gemma-3-1b-it": "google/gemma-3-1b-it",
    "Gemma-3-4b-it": "google/gemma-3-4b-it",
    "Gemma-2-2b-it": "google/gemma-2-2b-it",
    "Phi-4-mini-instruct": "microsoft/phi-4-mini-instruct",
    "Cogito-v1-preview-llama-3b": "deepcogito/cogito-v1-preview-llama-3b",
    "IBM Granite-3.3-2b-instruct": "ibm-granite/granite-3.3-2b-instruct",
    # "Bitnet-b1.58-2B4T": "microsoft/bitnet-b1.58-2B-4T",
    # "MiniCPM3-RAG-LoRA": "openbmb/MiniCPM3-RAG-LoRA",
    "Qwen3-0.6b": "qwen/qwen3-0.6b",
    "Qwen3-1.7b": "qwen/qwen3-1.7b",
    "Qwen3-4b": "qwen/qwen3-4b",
    "SmolLM2-1.7b-Instruct": "HuggingFaceTB/SmolLM2-1.7B-Instruct",
    "EXAONE-3.5-2.4B-instruct": "LGAI-EXAONE/EXAONE-3.5-2.4B-Instruct",
    "OLMo-2-1B-Instruct": "allenai/OLMo-2-0425-1B-Instruct",
    "icecream-3b": "aizip-dev/icecream-3b",
}
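# Tokenizers are cached across calls so repeated requests don't reload them.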
tokenizer_cache = {}
# List of model names for easy access
model_names = list(models.keys())
# Remove interrupt criteria class since we're not using it
# class InterruptCriteria(StoppingCriteria):
#     def __init__(self, interrupt_event):
#         self.interrupt_event = interrupt_event
#
#     def __call__(self, input_ids, scores, **kwargs):
#         return self.interrupt_event.is_set()
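# `@spaces.GPU` asks ZeroGPU to attach a GPU to the process for the duration of each decorated call.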
@spaces.GPU
def generate_summaries(example, model_a_name, model_b_name):
    """
    Generates summaries for the given example using the assigned models sequentially.
    """
    # Remove interrupt checks
    context_text = ""
    context_parts = []
    if "full_contexts" in example and example["full_contexts"]:
        for i, ctx in enumerate(example["full_contexts"]):
            content = ""
            # Extract content from either dict or string
            if isinstance(ctx, dict) and "content" in ctx:
                content = ctx["content"]
            elif isinstance(ctx, str):
                content = ctx
            # Add document number if not already present
            if not content.strip().startswith("Document"):
                content = f"Document {i + 1}:\n{content}"
            context_parts.append(content)
        context_text = "\n\n".join(context_parts)
    else:
        # Provide a graceful fallback instead of raising an error
        print("Warning: No full context found in the example, using empty context")
        context_text = ""

    question = example.get("question", "")

    print(f"Starting inference for Model A: {model_a_name}")
    # Run model A
    summary_a = run_inference(models[model_a_name], context_text, question)

    print(f"Starting inference for Model B: {model_b_name}")
    # Run model B
    summary_b = run_inference(models[model_b_name], context_text, question)

    print("Both models completed successfully")
    return summary_a, summary_b
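# Illustrative call shape (a sketch only; assumes a GPU is available and the chosen
# display names exist in `models`):
#
#     example = {
#         "full_contexts": [{"content": "Paris is the capital of France."}],
#         "question": "What is the capital of France?",
#     }
#     summary_a, summary_b = generate_summaries(example, "Qwen3-0.6b", "Gemma-3-1b-it")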
@spaces.GPU
def run_inference(model_name, context, question):
    """
    Run inference using the specified model.
    Returns the generated text.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    result = ""
    tokenizer_kwargs = {
        "add_generation_prompt": True,
    }  # make sure qwen3 doesn't use thinking
    generation_kwargs = {
        "max_new_tokens": 512,
    }
    if "qwen3" in model_name.lower():
        print(
            f"Recognized {model_name} as a Qwen3 model. Setting enable_thinking=False."
        )
        tokenizer_kwargs["enable_thinking"] = False
    try:
        if model_name in tokenizer_cache:
            tokenizer = tokenizer_cache[model_name]
        else:
            # Common arguments for tokenizer loading
            tokenizer_load_args = {"padding_side": "left", "token": True}
            actual_model_name_for_tokenizer = model_name
            if "icecream" in model_name.lower():
                actual_model_name_for_tokenizer = "meta-llama/llama-3.2-3b-instruct"
            tokenizer = AutoTokenizer.from_pretrained(
                actual_model_name_for_tokenizer, **tokenizer_load_args
            )
            tokenizer_cache[model_name] = tokenizer

        # Heuristic: some chat templates explicitly raise "System role not supported",
        # so only treat the model as accepting a system message when that string is absent.
        accepts_sys = (
            "System role not supported" not in tokenizer.chat_template
            if tokenizer.chat_template
            else False  # Handle missing chat_template
        )

        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
print("REACHED HERE BEFORE pipe")
print(f"Loading model {model_name}...")
if "bitnet" in model_name.lower():
bitnet_model = BitNetForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
)
pipe = pipeline(
"text-generation",
model=bitnet_model,
tokenizer=tokenizer,
torch_dtype=torch.bfloat16,
model_kwargs={
"attn_implementation": "eager",
},
)
elif "icecream" not in model_name.lower():
pipe = pipeline(
"text-generation",
model=model_name,
tokenizer=tokenizer,
device_map="cuda",
trust_remote_code=True,
torch_dtype=torch.bfloat16,
model_kwargs={
"attn_implementation": "eager",
},
)
else:
base_model = AutoModelForCausalLM.from_pretrained(
"meta-llama/llama-3.2-3b-instruct",
device_map="cuda",
torch_dtype=torch.bfloat16,
trust_remote_code=True,
)
model = PeftModel.from_pretrained(
base_model,
"aizip-dev/icecream-3b",
device_map="cuda",
torch_dtype=torch.bfloat16,
)
        text_input = format_rag_prompt(question, context, accepts_sys)

        print(f"Starting generation for {model_name}")
        if "gemma-3" in model_name.lower():
            result = pipe(
                text_input,
                max_new_tokens=512,
                generation_kwargs={"skip_special_tokens": True},
            )[0]["generated_text"]
            # The pipeline returns the whole chat (a list of messages) for chat-style
            # input; keep only the content of the final assistant turn.
            result = result[-1]["content"]
        elif "icecream" in model_name.lower():
            # The PEFT-wrapped model is driven directly with generate() rather than a pipeline.
            model_inputs = tokenizer.apply_chat_template(
                text_input,
                tokenize=True,
                return_tensors="pt",
                return_dict=True,
                **tokenizer_kwargs,
            )
            model_inputs = model_inputs.to(model.device)
            input_ids = model_inputs.input_ids
            attention_mask = model_inputs.attention_mask
            prompt_tokens_length = input_ids.shape[1]
            with torch.inference_mode():
                output_sequences = model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=512,
                    eos_token_id=tokenizer.eos_token_id,
                    pad_token_id=tokenizer.pad_token_id,
                )
            generated_token_ids = output_sequences[0][prompt_tokens_length:]
            result = tokenizer.decode(generated_token_ids, skip_special_tokens=True)
# elif "bitnet" in model_name.lower():
# formatted = tokenizer.apply_chat_template(
# text_input,
# tokenize=True,
# return_tensors="pt",
# return_dict=True,
# **tokenizer_kwargs,
# ).to(bitnet_model.device)
# with torch.inference_mode():
# output_sequences = bitnet_model.generate(
# **formatted,
# max_new_tokens=512,
# )
# result = tokenizer.decode(output_sequences[0][formatted['input_ids'].shape[-1]:], skip_special_tokens=True)
        else:  # For other models
            formatted = pipe.tokenizer.apply_chat_template(
                text_input,
                tokenize=False,
                **tokenizer_kwargs,
            )
            input_length = len(formatted)
            outputs = pipe(
                formatted,
                max_new_tokens=512,
                generation_kwargs={"skip_special_tokens": True},
            )
            # Slice off the prompt by character offset so only newly generated text remains.
            result = outputs[0]["generated_text"][input_length:]
        print(f"Generation completed for {model_name}")
    except Exception as e:
        print(f"Error in inference for {model_name}: {e}")
        print(traceback.format_exc())
        result = f"Error generating response: {str(e)[:200]}..."
    finally:
        # Clean up resources
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return result
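# Illustrative smoke test (a sketch only, not part of the Space's runtime path). Because of
# the relative import of `format_rag_prompt`, this module is meant to be imported as part of
# the package; something like the following could then be run from the app side, assuming a
# CUDA GPU and access to the referenced checkpoints:
#
#     demo_context = "Document 1:\nThe Eiffel Tower is located in Paris, France."
#     demo_question = "Where is the Eiffel Tower located?"
#     print(run_inference(models["Qwen3-0.6b"], demo_context, demo_question))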