File size: 3,967 Bytes
fea4095
 
1a1fb0e
 
 
fea4095
 
 
25c11ba
fea4095
 
25c11ba
7276d4c
fea4095
 
25c11ba
fea4095
bf2292c
1a1fb0e
 
 
25c11ba
1a1fb0e
25c11ba
 
 
 
 
 
 
 
1a1fb0e
25c11ba
 
1a1fb0e
25c11ba
 
 
1a1fb0e
25c11ba
 
 
1a1fb0e
25c11ba
1a1fb0e
 
 
25c11ba
 
1a1fb0e
25c11ba
fea4095
25c11ba
fea4095
 
 
 
fee88b4
25c11ba
 
 
 
1a1fb0e
25c11ba
 
 
 
 
 
fea4095
 
 
1a1fb0e
25c11ba
 
 
 
 
 
 
1a1fb0e
25c11ba
fea4095
 
1a1fb0e
25c11ba
 
fee88b4
1a1fb0e
 
cddc4c2
 
25c11ba
cddc4c2
 
25c11ba
 
 
 
cddc4c2
25c11ba
 
 
 
 
 
cddc4c2
25c11ba
 
 
 
 
cddc4c2
 
 
25c11ba
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download
import json

# Process-wide lazy cache. Both start empty; initialize() fills them the first
# time generate_text() needs the model, so importing this module stays cheap.
MODEL = TOKENIZER = None

def initialize():
    """Load the SmolLM2 model and tokenizer into the module-level cache.

    Does nothing when ``MODEL`` is already populated. The tokenizer and model
    are built in local variables and only published to ``MODEL``/``TOKENIZER``
    after both have loaded, so a failure part-way through never leaves a
    half-initialized cache behind (the next call simply retries).

    Raises:
        Exception: anything raised while loading (network errors, missing
            files, remote-code errors) is printed and re-raised unchanged.
    """
    global MODEL, TOKENIZER

    if MODEL is not None:
        return

    print("Loading model and tokenizer...")
    model_id = "jatingocodeo/SmolLM2"

    try:
        # Load tokenizer (from_pretrained fetches config/weights itself, so no
        # separate hf_hub_download of config.json is needed).
        print("Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_id)

        # Only fill in special tokens the tokenizer actually lacks. Overwriting
        # existing ones with strings that are not in the vocabulary creates new
        # token ids that the model has no embedding rows for.
        fallback_tokens = {
            'pad_token': '[PAD]',
            'eos_token': '</s>',
            'bos_token': '<s>',
        }
        missing = {
            name: token
            for name, token in fallback_tokens.items()
            if getattr(tokenizer, name) is None
        }
        if 'pad_token' in missing and tokenizer.eos_token is not None:
            # Reuse EOS for padding instead of growing the vocabulary.
            tokenizer.pad_token = tokenizer.eos_token
            del missing['pad_token']
        num_added = tokenizer.add_special_tokens(missing) if missing else 0

        # Load model
        print("Loading model...")
        use_cuda = torch.cuda.is_available()
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            trust_remote_code=True,
            low_cpu_mem_usage=True,
        )

        # Newly added special tokens need embedding rows, otherwise encoding
        # them (e.g. the prepended BOS) fails with an index error.
        if num_added > 0:
            model.resize_token_embeddings(len(tokenizer))

        # Move model to device
        device = torch.device("cuda" if use_cuda else "cpu")
        model.to(device)
    except Exception as e:
        print(f"Error initializing: {str(e)}")
        raise

    # Publish both together so callers never see one without the other.
    TOKENIZER = tokenizer
    MODEL = model
    print(f"Model loaded successfully on {device}")

def generate_text(prompt, max_length=100, temperature=0.7, top_k=50):
    """Continue *prompt* with the cached SmolLM2 model.

    Args:
        prompt: Text to continue. ``None``, non-string, empty or
            whitespace-only prompts are rejected before the model is loaded.
        max_length: Number of new tokens to generate (the "Max Length"
            slider). Coerced to ``int``.
        temperature: Sampling temperature forwarded to ``generate``.
        top_k: Top-k sampling cutoff forwarded to ``generate``. Coerced to
            ``int`` because ``generate`` requires an integer here.

    Returns:
        The decoded prompt plus continuation with special tokens removed, or
        a human-readable message when the prompt is empty or anything fails
        (including model initialization).
    """
    # Validate first so an empty prompt never triggers the slow first load.
    if not isinstance(prompt, str) or not prompt.strip():
        return "Please enter a prompt."

    # Total tokens (prompt + continuation) the model is asked to handle.
    context_length = 2048

    try:
        # Initialize if not already done
        if MODEL is None:
            initialize()

        # UI widgets may hand us floats; generate() wants integers.
        max_new_tokens = max(1, int(max_length))
        top_k = max(1, int(top_k))

        # Keep the prompt short enough that the requested continuation still
        # fits in the context window. The old `max_length=min(n + len, 2048)`
        # produced zero new tokens for a full-length prompt.
        prompt_budget = max(1, context_length - max_new_tokens)

        ids = TOKENIZER.encode(prompt)
        bos_id = TOKENIZER.bos_token_id
        # Prepend BOS unless the tokenizer (or the user) already supplied it.
        # Checking ids rather than text avoids a doubled BOS for tokenizers
        # that add it automatically.
        if bos_id is not None and (not ids or ids[0] != bos_id):
            ids = [bos_id] + ids
        ids = ids[:prompt_budget]

        input_ids = torch.tensor([ids], device=MODEL.device)
        # Explicit mask: without it generate() cannot tell real tokens from
        # padding when pad and eos share an id.
        attention_mask = torch.ones_like(input_ids)

        with torch.no_grad():
            outputs = MODEL.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_k=top_k,
                do_sample=True,
                pad_token_id=TOKENIZER.pad_token_id,
                eos_token_id=TOKENIZER.eos_token_id,
                num_return_sequences=1,
            )

        # Decode and return (prompt + continuation).
        generated_text = TOKENIZER.decode(outputs[0], skip_special_tokens=True)
        return generated_text.strip()

    except Exception as e:
        # Top-level UI boundary: report instead of crashing the request.
        print(f"Error generating text: {str(e)}")
        return f"An error occurred: {str(e)}"

# Gradio UI: a prompt box and three sampling sliders feed generate_text, whose
# result is shown in a single output text box.
_input_components = [
    gr.Textbox(label="Prompt", placeholder="Enter your prompt here...", lines=2),
    gr.Slider(minimum=10, maximum=200, value=100, step=1, label="Max Length"),
    gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"),
    gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top K"),
]
_output_component = gr.Textbox(label="Generated Text", lines=5)

_description = (
    "Generate text using the fine-tuned SmolLM2 model.\n"
    "    - Max Length: Controls the length of generated text\n"
    "    - Temperature: Controls randomness (higher = more creative)\n"
    "    - Top K: Controls diversity of word choices"
)

# Each example row lines up with _input_components:
# [prompt, max length, temperature, top k].
_examples = [
    ["Once upon a time", 100, 0.7, 50],
    ["The quick brown fox", 150, 0.8, 40],
    ["In a galaxy far far away", 200, 0.9, 30],
]

iface = gr.Interface(
    fn=generate_text,
    inputs=_input_components,
    outputs=_output_component,
    title="SmolLM2 Text Generator",
    description=_description,
    examples=_examples,
    allow_flagging="never",
)


def _main():
    """Start the Gradio server; only used when run as a script."""
    iface.launch()


if __name__ == "__main__":
    _main()