File size: 4,203 Bytes
fea4095
 
1a1fb0e
 
 
fea4095
 
 
25c11ba
fea4095
 
25c11ba
7276d4c
fea4095
 
25c11ba
fea4095
bf2292c
1a1fb0e
 
 
25c11ba
1a1fb0e
25c11ba
 
 
 
 
 
 
 
1a1fb0e
25c11ba
 
1a1fb0e
25c11ba
 
 
1a1fb0e
25c11ba
 
 
1a1fb0e
25c11ba
1a1fb0e
 
 
25c11ba
 
1a1fb0e
25c11ba
fea4095
25c11ba
fea4095
 
44302df
 
 
 
fea4095
fee88b4
25c11ba
 
 
 
1a1fb0e
25c11ba
 
 
 
 
 
fea4095
 
 
1a1fb0e
25c11ba
 
44302df
 
 
 
25c11ba
 
 
fea4095
 
1a1fb0e
25c11ba
 
fee88b4
44302df
 
 
cddc4c2
 
25c11ba
cddc4c2
 
25c11ba
 
 
 
cddc4c2
25c11ba
 
44302df
cddc4c2
25c11ba
 
 
 
cddc4c2
 
44302df
 
 
 
 
 
 
cddc4c2
25c11ba
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download
import json

# Module-level cache shared by every request. Both names start empty and
# are populated by initialize() the first time it succeeds.
MODEL = TOKENIZER = None

def initialize():
    """Load the SmolLM2 tokenizer and model into the module-level cache.

    Does nothing when ``MODEL`` is already set. The tokenizer and model are
    built in local variables. They are published to the ``TOKENIZER`` /
    ``MODEL`` globals only after both have loaded and the model has been
    moved to its device. A failure part-way through therefore leaves the
    cache empty, so a later call retries, instead of half-initialized.

    Raises:
        Exception: whatever the Hugging Face loaders, embedding resize, or
            ``.to(device)`` raise. The error is printed and then re-raised.
    """
    global MODEL, TOKENIZER

    if MODEL is not None:
        return

    print("Loading model and tokenizer...")
    model_id = "jatingocodeo/SmolLM2"

    try:
        # Load tokenizer (from_pretrained fetches config/tokenizer files from
        # the Hub itself, so no separate hf_hub_download is needed).
        print("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_id)

        # Add special tokens only where the tokenizer lacks them. Overwriting
        # existing ones (e.g. replacing the model's real EOS with "</s>")
        # would register tokens the model was never trained on.
        fallback_tokens = {
            'pad_token': '[PAD]',
            'eos_token': '</s>',
            'bos_token': '<s>',
        }
        missing_tokens = {
            name: token
            for name, token in fallback_tokens.items()
            if getattr(tokenizer, name) is None
        }
        # Number of genuinely new vocabulary entries (0 if nothing added).
        added_count = tokenizer.add_special_tokens(missing_tokens) if missing_tokens else 0

        # Load model
        print("Loading model...")
        use_cuda = torch.cuda.is_available()
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            trust_remote_code=True,
            low_cpu_mem_usage=True,
        )

        # Newly added tokens get ids past the end of the embedding table;
        # grow it so those ids (e.g. a prepended BOS) don't index out of range.
        if added_count and len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
            model.resize_token_embeddings(len(tokenizer))

        # Move model to device
        device = torch.device("cuda" if use_cuda else "cpu")
        model.to(device)

    except Exception as e:
        print(f"Error initializing: {str(e)}")
        raise

    # Publish only fully-initialized objects. MODEL is set last because
    # callers test `MODEL is None` to decide whether to initialize.
    TOKENIZER = tokenizer
    MODEL = model
    print(f"Model loaded successfully on {device}")

def generate_text(prompt, max_length=100, temperature=0.7, top_k=50):
    """Generate a continuation of *prompt* with the cached model.

    Args:
        prompt: Text to continue. ``None``, empty, or whitespace-only input
            is rejected with a message before the model is loaded.
        max_length: Maximum number of *new* tokens to generate (the UI's
            "Max Length" slider). Coerced to ``int``; clamped to
            [1, context_limit - 1].
        temperature: Sampling temperature, clamped to [0.1, 1.0] when
            sampling. A value <= 0 selects greedy decoding instead.
        top_k: Top-k sampling cutoff, coerced to ``int`` and clamped to
            [1, 100].

    Returns:
        The decoded first output sequence with special tokens removed and
        surrounding whitespace stripped. For a standard causal LM this is
        the prompt followed by its continuation. On failure, a
        human-readable error string is returned instead of raising, so the
        Gradio UI can display it.
    """
    # Reject blank input before paying for model initialization.
    if prompt is None or not prompt.strip():
        return "Please enter a prompt."

    # Initialize if not already done
    if MODEL is None:
        try:
            initialize()
        except Exception as e:
            return f"Failed to initialize model: {str(e)}"

    context_limit = 2048  # budget for prompt tokens + generated tokens

    try:
        # Gradio sliders may deliver floats; generate() requires ints here.
        new_tokens = max(1, min(int(max_length), context_limit - 1))
        k = max(1, min(int(top_k), 100))
        do_sample = temperature > 0

        # Add BOS token if the tokenizer defines one and it isn't already there.
        bos = TOKENIZER.bos_token
        if bos and not prompt.startswith(bos):
            prompt = bos + prompt

        # Truncate the prompt so that it plus the requested continuation fits
        # in the context window; otherwise a long prompt leaves no room and
        # nothing new is generated.
        encoded = TOKENIZER(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=context_limit - new_tokens,
        )
        input_ids = encoded["input_ids"].to(MODEL.device)
        attention_mask = encoded.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(MODEL.device)

        # Sampling knobs are only meaningful (and only passed) when sampling.
        sampling_kwargs = {}
        if do_sample:
            sampling_kwargs["temperature"] = max(0.1, min(temperature, 1.0))
            sampling_kwargs["top_k"] = k

        # Generate
        with torch.no_grad():
            outputs = MODEL.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=new_tokens,
                do_sample=do_sample,
                num_return_sequences=1,
                pad_token_id=TOKENIZER.pad_token_id,
                eos_token_id=TOKENIZER.eos_token_id,
                **sampling_kwargs,
            )

        # Decode and return
        generated_text = TOKENIZER.decode(outputs[0], skip_special_tokens=True)
        return generated_text.strip()

    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"Error during text generation: {str(e)}"

# Gradio UI: a prompt box followed by three generation controls, passed to
# generate_text() positionally in the same order as its parameters
# (prompt, max_length, temperature, top_k). Example rows follow that order.
iface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(lines=2, label="Prompt", placeholder="Enter your prompt here..."),
        gr.Slider(label="Max Length", minimum=10, maximum=200, step=1, value=100),
        gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, step=0.1, value=0.7),
        gr.Slider(label="Top K", minimum=1, maximum=100, step=1, value=50),
    ],
    outputs=gr.Textbox(lines=5, label="Generated Text"),
    examples=[
        ["Once upon a time", 100, 0.7, 50],
        ["The quick brown fox", 150, 0.8, 40],
    ],
    title="SmolLM2 Text Generator",
    description=(
        "Generate text using the fine-tuned SmolLM2 model. "
        "Adjust parameters to control the generation."
    ),
    allow_flagging="never",
)

# Eagerly load the model when the module is imported. A failure here is
# non-fatal: generate_text() calls initialize() again while MODEL is None.
try:
    initialize()
except Exception as exc:
    print(f"Warning: Model initialization failed: {exc}")
    print("Model will be initialized on first request")

# Start the web UI only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()