#!/usr/bin/env python3
"""
AR-Diffusion Chat Interface for Hugging Face Spaces
Experimental model with Quality vs Speed modes
Optimized for Zero GPU deployment with @spaces.GPU
"""
import gradio as gr
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
import random
import numpy as np
import re
import time
from typing import List, Tuple, Generator
import os
import gc
import spaces

# Global model variables for memory efficiency
tokenizer = None
model = None
current_generator = None
device = None


def get_noising_schedule(i, max_it, sharpness=5.0):
    """Exponential noise schedule for denoising"""
    x = i / max_it
    return (np.exp(-sharpness * x) - np.exp(-sharpness)) / (1 - np.exp(-sharpness))
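
# Note: get_noising_schedule() is not called elsewhere in this script and is kept for
# reference. As a rough illustration (values rounded), the schedule decays from 1.0 at
# step 0 toward 0.0 at the final step:
#   [round(get_noising_schedule(i, 10), 3) for i in (0, 5, 10)]  ->  [1.0, 0.076, 0.0]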


class ARDiffusionGenerator:
    """Base AR-Diffusion generator with shared functionality"""

    def __init__(self, tokenizer, model, device):
        self.tokenizer = tokenizer
        self.model = model
        self.device = device
        self.mask_token_id = self._find_mask_token()

    def _find_mask_token(self) -> int:
        """Find the MASK token ID used for corrupted positions"""
        for candidate in ['MASK', '<mask>', '[MASK]', '<|mask|>']:
            try:
                tokens = self.tokenizer.encode(candidate, add_special_tokens=False)
                if len(tokens) == 1:
                    return tokens[0]
            except Exception:
                continue
        # Fall back to the unknown-token id (or a fixed id if none is defined)
        return getattr(self.tokenizer, 'unk_token_id', 50257) or 50257

    def create_prompt(self, instruction: str) -> str:
        """Create Alpaca-style prompt"""
        return f"""### Instruction:
{instruction}
### Response:
"""


class QualityGenerator(ARDiffusionGenerator):
    """Quality-focused AR-Diffusion generator (from first script)"""

    def filter_logits(self, logits: torch.Tensor, top_k: int = 0, top_p: float = 1.0,
                      temperature: float = 1.0) -> torch.Tensor:
        """Research-grade filtering with proper order"""
        original_shape = logits.shape
        if logits.dim() == 3:
            logits = logits.squeeze(0)
        elif logits.dim() == 1:
            logits = logits.unsqueeze(0)
        logits = logits.clone()

        # Temperature scaling first
        if temperature != 1.0:
            logits = logits / temperature

        # Top-k filtering
        if top_k > 0 and top_k < logits.size(-1):
            topk_vals, _ = torch.topk(logits, top_k, dim=-1)
            thresholds = topk_vals[:, -1].unsqueeze(-1)
            logits = torch.where(logits < thresholds,
                                 torch.full_like(logits, float("-inf")), logits)

        # Top-p filtering
        if top_p > 0.0 and top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
            probs = torch.softmax(sorted_logits, dim=-1)
            cum_probs = probs.cumsum(dim=-1)
            mask = cum_probs > top_p
            mask[:, 0] = False
            scatter_mask = torch.zeros_like(logits, dtype=torch.bool).scatter(
                dim=-1, index=sorted_indices, src=mask)
            logits = torch.where(scatter_mask,
                                 torch.full_like(logits, float("-inf")), logits)

        # Restore original shape (only re-add the batch dim if the input was 3-D)
        if len(original_shape) == 1:
            logits = logits.squeeze(0)
        elif len(original_shape) == 3 and logits.dim() == 2:
            logits = logits.unsqueeze(0)
        return logits
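
    # Note (illustrative): filtering is applied in the order temperature -> top-k -> top-p,
    # so the nucleus cutoff operates on the already truncated top-k candidates, and
    # mask[:, 0] = False guarantees the single most likely token always survives,
    # so sampling never runs out of options.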

    def generate_start(self, prompt: str, length: int = 8) -> List[int]:
        """Generate natural start"""
        tokens = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        input_ids = tokens['input_ids'][0]
        generated = []
        current = input_ids.clone()
        with torch.no_grad():
            for _ in range(length):
                outputs = self.model(input_ids=current.unsqueeze(0))
                logits = outputs.logits[0, -1]
                filtered_logits = self.filter_logits(
                    logits, top_k=50, top_p=0.9, temperature=0.8
                )
                probs = F.softmax(filtered_logits, dim=-1)
                next_token = torch.multinomial(probs, 1).item()
                # Stop on EOS or the hard-coded stop-token ids used by this model
                if next_token in [self.tokenizer.eos_token_id, 128001, 13]:
                    break
                generated.append(next_token)
                current = torch.cat([current, torch.tensor([next_token], device=self.device)])
        return generated

    def create_sequence(self, prompt: str) -> Tuple[str, torch.Tensor]:
        """Create corrupted sequence for quality mode"""
        prompt_tokens = self.tokenizer(prompt, return_tensors="pt")['input_ids'][0]
        natural_start = self.generate_start(prompt, length=random.randint(8, 12))

        # Longer sequences for better quality
        prompt_length = len(prompt_tokens)
        if prompt_length > 25:
            num_masks = random.randint(35, 50)
        elif prompt_length > 15:
            num_masks = random.randint(25, 40)
        else:
            num_masks = random.randint(20, 35)

        sequence = (
            prompt_tokens.tolist() +
            natural_start +
            [self.mask_token_id] * num_masks +
            [13]
        )
        tensor = torch.tensor(sequence)
        text = self.tokenizer.decode(tensor, skip_special_tokens=False)
        return text, tensor

    def generate(self, prompt: str, progress_callback=None) -> Tuple[str, dict]:
        """Quality generation with progress updates and speed tracking"""
        steps = 40
        temperature = 0.7
        start_time = time.time()

        if progress_callback:
            progress_callback(0.1, "Creating sequence...")
        full_prompt = self.create_prompt(prompt)
        corrupted_text, corrupted_ids = self.create_sequence(full_prompt)

        if progress_callback:
            progress_callback(0.2, "Starting quality denoising...")
        result, stats = self._denoise_quality(corrupted_ids, steps, temperature, progress_callback)

        # Calculate overall stats
        total_time = time.time() - start_time
        response = self._clean_response(result)
        word_count = len(response.split())
        stats.update({
            'total_time': total_time,
            'word_count': word_count,
            'words_per_second': word_count / total_time if total_time > 0 else 0
        })
        return response, stats

    def _denoise_quality(self, corrupted_ids: torch.Tensor, steps: int, temperature: float, progress_callback=None) -> Tuple[str, dict]:
        """Quality denoising with progress updates and speed tracking"""
        current_ids = corrupted_ids.clone()
        total_replacements = 0
        start_time = time.time()

        for step in range(steps):
            step_start = time.time()
            if progress_callback:
                progress = 0.2 + (step / steps) * 0.7
                elapsed = time.time() - start_time
                tokens_per_sec = total_replacements / elapsed if elapsed > 0 else 0
                progress_callback(progress, f"Quality step {step+1}/{steps} | {tokens_per_sec:.1f} tok/s")

            mask_positions = (current_ids == self.mask_token_id).nonzero(as_tuple=True)[0]
            if len(mask_positions) == 0:
                break

            with torch.no_grad():
                outputs = self.model(input_ids=current_ids.unsqueeze(0).to(self.device))
                logits = outputs.logits[0]

            current_temp = max(0.4, temperature * (1 - step / steps))

            # Conservative replacement for quality
            if step < steps // 4:
                max_replacements = min(1, len(mask_positions))
            elif step < steps // 2:
                max_replacements = min(2, len(mask_positions))
            else:
                max_replacements = min(3, len(mask_positions))

            sorted_positions = sorted(mask_positions.tolist())
            step_replacements = 0
            for pos in sorted_positions[:max_replacements]:
                if pos < len(logits):
                    token_logits = logits[pos].clone()

                    # Anti-repetition
                    context_start = max(0, pos - 5)
                    recent_tokens = set(current_ids[context_start:pos].tolist())
                    for recent_token in recent_tokens:
                        if recent_token < len(token_logits):
                            token_logits[recent_token] -= 8.0

                    # Quality filtering
                    filtered_logits = self.filter_logits(
                        token_logits,
                        top_k=30,
                        top_p=0.75,
                        temperature=current_temp
                    )
                    probs = F.softmax(filtered_logits, dim=-1)
                    probs = torch.clamp(probs, min=1e-8, max=1.0)
                    new_token = torch.multinomial(probs, 1).item()

                    # Filter unwanted tokens
                    unwanted = [self.mask_token_id, 128001, 128000]
                    if new_token in unwanted:
                        top_k_vals, top_k_indices = torch.topk(filtered_logits, 10)
                        for alternative in top_k_indices:
                            if alternative.item() not in unwanted:
                                new_token = alternative.item()
                                break

                    current_ids[pos] = new_token
                    step_replacements += 1
                    total_replacements += 1

        if progress_callback:
            elapsed = time.time() - start_time
            final_speed = total_replacements / elapsed if elapsed > 0 else 0
            progress_callback(0.95, f"Finalizing... | Final speed: {final_speed:.1f} tok/s")

        # Calculate final statistics
        total_time = time.time() - start_time
        stats = {
            'mode': 'Quality',
            'steps': steps,
            'tokens_replaced': total_replacements,
            'generation_time': total_time,
            'tokens_per_second': total_replacements / total_time if total_time > 0 else 0
        }
        result = self.tokenizer.decode(current_ids, skip_special_tokens=True)
        return result, stats

    def _clean_response(self, text: str) -> str:
        """Clean response for quality output"""
        if "### Response:" in text:
            response = text.split("### Response:")[-1].strip()
        else:
            response = text.strip()
        if not response:
            return text

        # Quality cleaning
        response = re.sub(r"'{2,}", "", response)
        response = re.sub(r'"{2,}', "", response)
        response = re.sub(r"\.{2,}", ".", response)
        response = re.sub(r",{2,}", ",", response)
        response = re.sub(r"\s+", " ", response)

        # Remove artifacts
        response = re.sub(r"\$+", "", response)
        response = re.sub(r"#+", "", response)
        response = re.sub(r"@+", "", response)
        response = response.strip()

        if response and not response.endswith(('.', '!', '?')):
            response += "."
        return response


class SpeedGenerator(ARDiffusionGenerator):
    """Speed-focused AR-Diffusion generator (from second script)"""

    def filter_logits(self, logits: torch.Tensor, top_k: int = 15, top_p: float = 0.8,
                      temperature: float = 1.0) -> torch.Tensor:
        """Fast logits filtering (expects 1-D logits)"""
        logits = logits.clone()
        if temperature != 1.0:
            logits = logits / temperature

        # Top-k filtering
        if top_k > 0 and top_k < logits.size(-1):
            topk_vals, _ = torch.topk(logits, top_k, dim=-1)
            threshold = topk_vals[-1]
            logits = torch.where(logits < threshold, torch.full_like(logits, float("-inf")), logits)

        # Top-p filtering
        if top_p > 0.0 and top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
            probs = torch.softmax(sorted_logits, dim=-1)
            cum_probs = probs.cumsum(dim=-1)
            mask = cum_probs > top_p
            mask[0] = False
            scatter_mask = torch.zeros_like(logits, dtype=torch.bool)
            scatter_mask.scatter_(0, sorted_indices, mask)
            logits = torch.where(scatter_mask, torch.full_like(logits, float("-inf")), logits)
        return logits

    def generate_start(self, prompt: str, length: int = 6) -> List[int]:
        """Generate natural start for speed mode"""
        tokens = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        input_ids = tokens['input_ids'][0]
        generated = []
        current = input_ids.clone()
        with torch.no_grad():
            for _ in range(length):
                outputs = self.model(input_ids=current.unsqueeze(0))
                logits = outputs.logits[0, -1]
                filtered_logits = self.filter_logits(logits, top_k=20, top_p=0.9, temperature=0.8)
                probs = F.softmax(filtered_logits, dim=-1)
                next_token = torch.multinomial(probs, 1).item()
                if next_token in [self.tokenizer.eos_token_id, 128001, 13]:
                    break
                generated.append(next_token)
                current = torch.cat([current, torch.tensor([next_token], device=self.device)])
        return generated

    def create_sequence(self, prompt: str) -> Tuple[str, torch.Tensor]:
        """Create sequence optimized for speed"""
        prompt_tokens = self.tokenizer(prompt, return_tensors="pt")['input_ids'][0]
        natural_start = self.generate_start(prompt, length=6)

        # Shorter sequences for speed
        prompt_words = len(prompt.split())
        if prompt_words > 8:
            num_masks = random.randint(15, 25)
        else:
            num_masks = random.randint(12, 20)

        sequence = (
            prompt_tokens.tolist() +
            natural_start +
            [self.mask_token_id] * num_masks +
            [13]
        )
        tensor = torch.tensor(sequence)
        text = self.tokenizer.decode(tensor, skip_special_tokens=False)
        return text, tensor

    def generate(self, prompt: str, progress_callback=None) -> Tuple[str, dict]:
        """Speed generation with progress updates and speed tracking"""
        steps = 10
        temperature = 0.8
        start_time = time.time()

        if progress_callback:
            progress_callback(0.1, "Creating sequence...")
        full_prompt = self.create_prompt(prompt)
        corrupted_text, corrupted_ids = self.create_sequence(full_prompt)

        if progress_callback:
            progress_callback(0.2, "Starting speed denoising...")
        result, stats = self._denoise_speed(corrupted_ids, steps, temperature, progress_callback)

        # Calculate overall stats
        total_time = time.time() - start_time
        response = self._clean_response(result)
        word_count = len(response.split())
        stats.update({
            'total_time': total_time,
            'word_count': word_count,
            'words_per_second': word_count / total_time if total_time > 0 else 0
        })
        return response, stats

    def _denoise_speed(self, corrupted_ids: torch.Tensor, steps: int, temperature: float, progress_callback=None) -> Tuple[str, dict]:
        """Ultra-fast denoising with progress updates and speed tracking"""
        current_ids = corrupted_ids.clone()
        total_replacements = 0
        start_time = time.time()

        # Use mixed precision for speed on GPU
        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=self.device.type == 'cuda'):
            for step in range(steps):
                step_start = time.time()
                if progress_callback:
                    progress = 0.2 + (step / steps) * 0.7
                    elapsed = time.time() - start_time
                    tokens_per_sec = total_replacements / elapsed if elapsed > 0 else 0
                    progress_callback(progress, f"Speed step {step+1}/{steps} | {tokens_per_sec:.1f} tok/s")

                mask_pos = (current_ids == self.mask_token_id).nonzero(as_tuple=True)[0]
                if len(mask_pos) == 0:
                    break

                with torch.no_grad():
                    outputs = self.model(input_ids=current_ids.unsqueeze(0).to(self.device))
                    logits = outputs.logits[0]

                current_temp = temperature * (0.9 + 0.2 * (step / steps))

                # Aggressive replacement for speed
                max_replace = min(8, len(mask_pos))
                positions = sorted(mask_pos.tolist())[:max_replace]
                step_replacements = 0
                for pos in positions:
                    if pos < len(logits):
                        token_logits = logits[pos].clone()

                        # Light anti-repetition
                        recent_start = max(0, pos - 3)
                        recent_tokens = set(current_ids[recent_start:pos].tolist())
                        for token in recent_tokens:
                            if token < len(token_logits):
                                token_logits[token] -= 3.0

                        # Fast filtering
                        filtered_logits = self.filter_logits(
                            token_logits, top_k=12, top_p=0.85, temperature=current_temp
                        )
                        probs = F.softmax(filtered_logits, dim=-1)
                        probs = torch.clamp(probs, min=1e-8, max=1.0)
                        new_token = torch.multinomial(probs, 1).item()

                        # Quick filtering
                        if new_token in [self.mask_token_id, 128001, 128000]:
                            top_vals, top_indices = torch.topk(filtered_logits, 3)
                            new_token = top_indices[1].item()

                        current_ids[pos] = new_token
                        step_replacements += 1
                        total_replacements += 1

        if progress_callback:
            elapsed = time.time() - start_time
            final_speed = total_replacements / elapsed if elapsed > 0 else 0
            progress_callback(0.95, f"Finalizing... | Final speed: {final_speed:.1f} tok/s")

        # Calculate final statistics
        total_time = time.time() - start_time
        stats = {
            'mode': 'Speed',
            'steps': steps,
            'tokens_replaced': total_replacements,
            'generation_time': total_time,
            'tokens_per_second': total_replacements / total_time if total_time > 0 else 0
        }
        result = self.tokenizer.decode(current_ids, skip_special_tokens=True)
        return result, stats

    def _clean_response(self, text: str) -> str:
        """Clean response for speed output"""
        if "### Response:" in text:
            response = text.split("### Response:")[-1].strip()
        else:
            response = text.strip()
        if not response:
            return text

        # Minimal cleaning for speed
        response = re.sub(r"'{3,}", "", response)
        response = re.sub(r'"{3,}', "", response)
        response = re.sub(r"\.{3,}", ".", response)
        response = re.sub(r",{3,}", ",", response)
        response = re.sub(r"\s+", " ", response)
        response = response.strip()

        if response and not response.endswith(('.', '!', '?')):
            response += "."
        return response


def load_model():
    """Load model and tokenizer (called from the @spaces.GPU-decorated chat function)"""
    global tokenizer, model, device
    if tokenizer is not None and model is not None:
        return tokenizer, model, device

    model_path = "rootxhacker/llama-3B-diffusion-exp-fixed"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Loading model on {device}...")

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
        device_map="auto" if device.type == "cuda" else None,
        trust_remote_code=True,
        low_cpu_mem_usage=True
    )
    return tokenizer, model, device


def cleanup_memory():
    """Clean up GPU memory"""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()


@spaces.GPU
def chat_function(message, history, mode, progress=gr.Progress()):
    """Main chat function with @spaces.GPU decorator, progress tracking, and speed display"""
    if not message.strip():
        return history, "", ""
    try:
        # Load model (this will run on GPU when GPU is allocated)
        progress(0.05, description="Loading model on GPU...")
        tok, mod, dev = load_model()

        # Create appropriate generator
        if mode == "Quality (Slower, Better)":
            generator = QualityGenerator(tok, mod, dev)
            progress(0.1, description="Initializing quality mode...")
        else:
            generator = SpeedGenerator(tok, mod, dev)
            progress(0.1, description="Initializing speed mode...")

        # Generate response with progress callback
        def progress_callback(pct, desc):
            progress(pct, description=desc)

        response, stats = generator.generate(message, progress_callback)
        progress(1.0, description="Complete!")

        # Create performance info
        perf_info = f"""**⚡ Performance Stats:**
- **Mode:** {stats['mode']}
- **Generation Time:** {stats['generation_time']:.2f}s
- **Tokens Replaced:** {stats['tokens_replaced']}
- **Speed:** {stats['tokens_per_second']:.1f} tokens/sec
- **Words Generated:** {stats['word_count']} words
- **Words/Second:** {stats['words_per_second']:.1f}
- **Steps:** {stats['steps']}"""

        # Update history
        history.append([message, response])

        # Cleanup memory for Zero GPU efficiency
        cleanup_memory()
        return history, "", perf_info

    except Exception as e:
        error_msg = f"Error: {str(e)}"
        history.append([message, error_msg])
        cleanup_memory()
        return history, "", "**❌ Error occurred during generation**"


def clear_chat():
    """Clear chat history and cleanup memory"""
    cleanup_memory()
    return [], ""


# Create Gradio interface
def create_interface():
    with gr.Blocks(
        title="AR-Diffusion Chat - Experimental Model",
        theme=gr.themes.Soft(),
        css="""
        .warning-box {
            background-color: #fff3cd;
            border: 1px solid #ffeaa7;
            border-radius: 5px;
            padding: 10px;
            margin: 10px 0;
        }
        """
    ) as interface:
        gr.HTML("""
        <div style="text-align: center; margin-bottom: 20px;">
            <h1>🧪 AR-Diffusion Chat Interface</h1>
            <p><strong>⚠️ EXPERIMENTAL MODEL ⚠️</strong></p>
            <p>This is an experimental AR-Diffusion model. Results may vary and the model is still under development.</p>
            <p><em>🔥 Powered by Zero GPU with @spaces.GPU</em></p>
        </div>
        """)

        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(
                    [],
                    elem_id="chatbot",
                    bubble_full_width=False,
                    height=500,
                    show_label=False
                )
                with gr.Row():
                    msg = gr.Textbox(
                        placeholder="Type your message here...",
                        show_label=False,
                        scale=9
                    )
                    send_btn = gr.Button("Send", scale=1, variant="primary")
                with gr.Row():
                    clear_btn = gr.Button("Clear Chat", variant="secondary")

            with gr.Column(scale=1):
                gr.HTML("""
                <div class="warning-box">
                    <h3>⚙️ Mode Selection</h3>
                    <p><strong>Quality Mode:</strong> Slower but more coherent responses (~40 steps)</p>
                    <p><strong>Speed Mode:</strong> Faster responses with decent quality (~10 steps)</p>
                    <p><em>🔥 GPU acceleration via @spaces.GPU</em></p>
                </div>
                """)
                mode = gr.Radio(
                    choices=["Quality (Slower, Better)", "Speed (Faster)"],
                    value="Quality (Slower, Better)",
                    label="Generation Mode"
                )

                # Performance display
                perf_display = gr.Markdown(
                    "**⚡ Performance Stats:** *Generate a message to see stats*",
                    elem_id="performance"
                )

                gr.HTML("""
                <div class="warning-box">
                    <h3>ℹ️ About AR-Diffusion</h3>
                    <p>This experimental model uses autoregressive diffusion for text generation, creating responses by iteratively denoising masked tokens.</p>
                    <br>
                    <p><strong>Note:</strong> This model is experimental and may produce unexpected results.</p>
                </div>
                """)

        # Event handlers
        def submit_message(message, history, mode):
            return chat_function(message, history, mode)

        send_btn.click(
            submit_message,
            inputs=[msg, chatbot, mode],
            outputs=[chatbot, msg, perf_display]
        )
        msg.submit(
            submit_message,
            inputs=[msg, chatbot, mode],
            outputs=[chatbot, msg, perf_display]
        )
        clear_btn.click(
            clear_chat,
            outputs=[chatbot, perf_display]
        )
    return interface


# Launch interface
if __name__ == "__main__":
    demo = create_interface()
    demo.queue(max_size=20)  # Important for Zero GPU
    demo.launch(
        share=False,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True
    )