from datetime import datetime

import streamlit as st

from utils import add_log

# Initialize huggingface_models in session state if not present
if 'huggingface_models' not in st.session_state:
    st.session_state.huggingface_models = [
        "codegen-350M-mono",
        "codegen-2B-mono",
        "Salesforce/codegen-350M-mono",
        "Salesforce/codegen-2B-mono",
        "gpt2",
        "EleutherAI/gpt-neo-125M"
    ]

# Handle missing dependencies
try:
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM
    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False
    
    # Lightweight mocks that mimic just enough of the transformers API
    # (from_pretrained, __call__, generate, decode) to keep demo mode usable
    class DummyTokenizer:
        @classmethod
        def from_pretrained(cls, model_name):
            return cls()
            
        def __call__(self, text, **kwargs):
            return {"input_ids": [list(range(10))] * (1 if isinstance(text, str) else len(text))}
            
        def decode(self, token_ids, **kwargs):
            return "# Generated code placeholder\n\ndef example_function():\n    return 'Hello world!'"
        
        @property
        def eos_token(self):
            return "[EOS]"
            
        @property
        def eos_token_id(self):
            return 0
            
        @property
        def pad_token(self):
            return None
            
        @pad_token.setter
        def pad_token(self, value):
            pass
            
    class DummyModel:
        @classmethod
        def from_pretrained(cls, model_name):
            return cls()
            
        def generate(self, input_ids, **kwargs):
            return [[1, 2, 3, 4, 5]]
            
        @property
        def config(self):
            class Config:
                @property
                def eos_token_id(self):
                    return 0
                    
                @property
                def pad_token_id(self):
                    return 0
                
                @pad_token_id.setter
                def pad_token_id(self, value):
                    pass
                    
            return Config()
            
    # Alias the mocks so the rest of the module can reference the
    # transformers names unchanged
    AutoTokenizer = DummyTokenizer
    AutoModelForCausalLM = DummyModel

def list_available_huggingface_models():
    """
    List available code generation models from Hugging Face.
    
    Returns:
        list: List of model names
    """
    # Return the list stored in session state
    return st.session_state.huggingface_models

def get_model_and_tokenizer(model_name):
    """
    Load model and tokenizer from Hugging Face Hub.
    
    Args:
        model_name: Name of the model to load
        
    Returns:
        tuple: (model, tokenizer) or (None, None) if loading fails
    """
    try:
        add_log(f"Loading model and tokenizer: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)
        add_log(f"Model and tokenizer loaded successfully: {model_name}")
        return model, tokenizer
    except Exception as e:
        add_log(f"Error loading model {model_name}: {str(e)}", "ERROR")
        return None, None
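
# Illustrative usage (a sketch: assumes the name resolves on the Hugging Face
# Hub and the weights fit in memory; with transformers missing, the mock
# classes load instantly):
#
#   model, tokenizer = get_model_and_tokenizer("Salesforce/codegen-350M-mono")
#   if model is None:
#       st.error("Model failed to load; check the logs for details.")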

def save_trained_model(model_id, model, tokenizer):
    """
    Save trained model information to session state.
    
    Args:
        model_id: Identifier for the model
        model: The trained model
        tokenizer: The model's tokenizer
        
    Returns:
        bool: Success status
    """
    try:
        # Ensure the container exists before writing to it
        if 'trained_models' not in st.session_state:
            st.session_state.trained_models = {}

        # Store model information in session state
        st.session_state.trained_models[model_id] = {
            'model': model,
            'tokenizer': tokenizer,
            'info': {
                'id': model_id,
                'created_at': datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            }
        }
        add_log(f"Model {model_id} saved to session state")
        return True
    except Exception as e:
        add_log(f"Error saving model {model_id}: {str(e)}", "ERROR")
        return False
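
# Illustrative usage after a fine-tuning run ("my-codegen-ft" is a
# hypothetical model id chosen for the example):
#
#   if save_trained_model("my-codegen-ft", model, tokenizer):
#       st.success("Model saved for this session.")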

def list_trained_models():
    """
    List all trained models in session state.
    
    Returns:
        list: List of model IDs
    """
    if 'trained_models' in st.session_state:
        return list(st.session_state.trained_models.keys())
    return []

def generate_code(model_id, prompt, max_length=100, temperature=0.7, top_p=0.9):
    """
    Generate code using a trained model.
    
    Args:
        model_id: ID of the model to use
        prompt: Input prompt for code generation
        max_length: Maximum total token length (prompt plus generated tokens)
        temperature: Sampling temperature
        top_p: Nucleus sampling probability
        
    Returns:
        str: Generated code or error message
    """
    try:
        # Use .get() so an uninitialized trained_models dict does not raise
        if model_id not in st.session_state.get('trained_models', {}):
            return "Error: Model not found. Please select a valid model."
        
        model_data = st.session_state.trained_models[model_id]
        model = model_data['model']
        tokenizer = model_data['tokenizer']
        
        if TRANSFORMERS_AVAILABLE:
            # GPT-style tokenizers often ship without a pad token; reuse EOS
            # so that padding does not raise an error
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            # Tokenize the prompt
            inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)

            # Generate text (do_sample=True is required for temperature and
            # top_p to have any effect)
            with torch.no_grad():
                outputs = model.generate(
                    inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    max_length=max_length,
                    do_sample=True,
                    temperature=temperature,
                    top_p=top_p,
                    num_return_sequences=1,
                    pad_token_id=tokenizer.eos_token_id
                )
            
            # Decode the generated text
            generated_code = tokenizer.decode(outputs[0], skip_special_tokens=True)
        else:
            # Demo mode - return dummy generated code
            inputs = tokenizer(prompt)
            outputs = model.generate(inputs["input_ids"])
            generated_code = tokenizer.decode(outputs[0])
            
            # Add some context to the generated code based on the prompt
            if "fibonacci" in prompt.lower():
                generated_code = "def fibonacci(n):\n    if n <= 0:\n        return 0\n    elif n == 1:\n        return 1\n    else:\n        return fibonacci(n-1) + fibonacci(n-2)\n"
            elif "sort" in prompt.lower():
                generated_code = "def bubble_sort(arr):\n    n = len(arr)\n    for i in range(n):\n        for j in range(0, n-i-1):\n            if arr[j] > arr[j+1]:\n                arr[j], arr[j+1] = arr[j+1], arr[j]\n    return arr\n"
        
        # If the prompt is included in the output, remove it to get only the generated code
        if generated_code.startswith(prompt):
            generated_code = generated_code[len(prompt):]
            
        return generated_code
        
    except Exception as e:
        add_log(f"Error generating code: {str(e)}", "ERROR")
        return f"Error generating code: {str(e)}"

def get_model_info(model_id):
    """
    Get information about a model.
    
    Args:
        model_id: ID of the model
        
    Returns:
        dict: Model information
    """
    if 'trained_models' in st.session_state and model_id in st.session_state.trained_models:
        return st.session_state.trained_models[model_id]['info']
    return None
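
# Illustrative page wiring (a sketch of how these helpers might be combined
# in the Streamlit UI; the widget labels and layout are assumptions):
#
#   model_id = st.selectbox("Trained model", list_trained_models())
#   prompt = st.text_area("Prompt", "def fibonacci(n):")
#   if st.button("Generate") and model_id:
#       st.code(generate_code(model_id, prompt, max_length=200),
#               language="python")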