# CodeCraftLab / model_utils.py
import streamlit as st
from datetime import datetime

from utils import add_log
# Initialize huggingface_models in session state if not present
if 'huggingface_models' not in st.session_state:
    # Fully qualified Hub IDs; the CodeGen models live under the Salesforce
    # org, so bare "codegen-*" names would not resolve via from_pretrained
    st.session_state.huggingface_models = [
        "Salesforce/codegen-350M-mono",
        "Salesforce/codegen-2B-mono",
        "gpt2",
        "EleutherAI/gpt-neo-125M",
    ]
# Handle missing dependencies
try:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
TRANSFORMERS_AVAILABLE = True
except ImportError:
TRANSFORMERS_AVAILABLE = False
# Mock classes for demo purposes
class DummyTokenizer:
@classmethod
def from_pretrained(cls, model_name):
return cls()
def __call__(self, text, **kwargs):
return {"input_ids": [list(range(10))] * (1 if isinstance(text, str) else len(text))}
def decode(self, token_ids, **kwargs):
return "# Generated code placeholder\n\ndef example_function():\n return 'Hello world!'"
@property
def eos_token(self):
return "[EOS]"
@property
def eos_token_id(self):
return 0
@property
def pad_token(self):
return None
@pad_token.setter
def pad_token(self, value):
pass
class DummyModel:
@classmethod
def from_pretrained(cls, model_name):
return cls()
def generate(self, input_ids, **kwargs):
return [[1, 2, 3, 4, 5]]
@property
def config(self):
class Config:
@property
def eos_token_id(self):
return 0
@property
def pad_token_id(self):
return 0
@pad_token_id.setter
def pad_token_id(self, value):
pass
return Config()
# Set aliases to match transformers
AutoTokenizer = DummyTokenizer
AutoModelForCausalLM = DummyModel
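
# Callers can branch on TRANSFORMERS_AVAILABLE to tell real inference apart
# from demo mode. Illustrative sketch (assumes a Streamlit page context):
#
#     if not TRANSFORMERS_AVAILABLE:
#         st.warning("torch/transformers not installed - running in demo mode")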
def list_available_huggingface_models():
"""
List available code generation models from Hugging Face.
Returns:
list: List of model names
"""
# Return the list stored in session state
return st.session_state.huggingface_models
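
# Example usage (illustrative sketch; assumes this is called from a Streamlit
# page, e.g. to populate a model picker):
#
#     base_model = st.selectbox("Base model", list_available_huggingface_models())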
def get_model_and_tokenizer(model_name):
"""
Load model and tokenizer from Hugging Face Hub.
Args:
model_name: Name of the model to load
Returns:
tuple: (model, tokenizer) or (None, None) if loading fails
"""
    try:
        add_log(f"Loading model and tokenizer: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # GPT-2-style tokenizers ship without a pad token, which makes
        # tokenization with padding=True fail; fall back to the EOS token
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(model_name)
        add_log(f"Model and tokenizer loaded successfully: {model_name}")
        return model, tokenizer
except Exception as e:
add_log(f"Error loading model {model_name}: {str(e)}", "ERROR")
return None, None
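
# Example usage (illustrative sketch; "gpt2" is one of the small models listed
# above, chosen here only because it downloads quickly):
#
#     model, tokenizer = get_model_and_tokenizer("gpt2")
#     if model is None:
#         st.error("Model failed to load - see the log for details.")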
def save_trained_model(model_id, model, tokenizer):
"""
Save trained model information to session state.
Args:
model_id: Identifier for the model
model: The trained model
tokenizer: The model's tokenizer
Returns:
bool: Success status
"""
    try:
        # Make sure the registry exists before writing to it
        if 'trained_models' not in st.session_state:
            st.session_state.trained_models = {}
        # Store model information in session state
        st.session_state.trained_models[model_id] = {
'model': model,
'tokenizer': tokenizer,
'info': {
'id': model_id,
'created_at': datetime.now().strftime("%Y-%m-%d %H:%M:%S")
}
}
add_log(f"Model {model_id} saved to session state")
return True
except Exception as e:
add_log(f"Error saving model {model_id}: {str(e)}", "ERROR")
return False
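
# Example usage (illustrative sketch; "my-finetune" is a hypothetical ID, and
# model/tokenizer are assumed to come from get_model_and_tokenizer or a
# training run):
#
#     if save_trained_model("my-finetune", model, tokenizer):
#         st.success("Model registered in this session.")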
def list_trained_models():
"""
List all trained models in session state.
Returns:
list: List of model IDs
"""
if 'trained_models' in st.session_state:
return list(st.session_state.trained_models.keys())
return []
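
# Example usage (illustrative sketch; pairs naturally with st.selectbox):
#
#     trained = list_trained_models()
#     if trained:
#         model_id = st.selectbox("Trained model", trained)
#     else:
#         st.info("No trained models in this session yet.")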
def generate_code(model_id, prompt, max_length=100, temperature=0.7, top_p=0.9):
"""
Generate code using a trained model.
Args:
model_id: ID of the model to use
prompt: Input prompt for code generation
        max_length: Maximum total sequence length (prompt + generated tokens)
temperature: Sampling temperature
top_p: Nucleus sampling probability
Returns:
str: Generated code or error message
"""
try:
        if ('trained_models' not in st.session_state
                or model_id not in st.session_state.trained_models):
return "Error: Model not found. Please select a valid model."
model_data = st.session_state.trained_models[model_id]
model = model_data['model']
tokenizer = model_data['tokenizer']
if TRANSFORMERS_AVAILABLE:
# Tokenize the prompt
inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
            # Generate text; do_sample=True is required for temperature and
            # top_p to take effect (greedy decoding silently ignores them)
            with torch.no_grad():
                outputs = model.generate(
                    inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    max_length=max_length,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=True,
                    num_return_sequences=1,
                    pad_token_id=tokenizer.eos_token_id
                )
# Decode the generated text
generated_code = tokenizer.decode(outputs[0], skip_special_tokens=True)
else:
# Demo mode - return dummy generated code
inputs = tokenizer(prompt)
outputs = model.generate(inputs["input_ids"])
generated_code = tokenizer.decode(outputs[0])
# Add some context to the generated code based on the prompt
if "fibonacci" in prompt.lower():
generated_code = "def fibonacci(n):\n if n <= 0:\n return 0\n elif n == 1:\n return 1\n else:\n return fibonacci(n-1) + fibonacci(n-2)\n"
elif "sort" in prompt.lower():
generated_code = "def bubble_sort(arr):\n n = len(arr)\n for i in range(n):\n for j in range(0, n-i-1):\n if arr[j] > arr[j+1]:\n arr[j], arr[j+1] = arr[j+1], arr[j]\n return arr\n"
# If the prompt is included in the output, remove it to get only the generated code
if generated_code.startswith(prompt):
generated_code = generated_code[len(prompt):]
return generated_code
except Exception as e:
add_log(f"Error generating code: {str(e)}", "ERROR")
return f"Error generating code: {str(e)}"
def get_model_info(model_id):
"""
Get information about a model.
Args:
model_id: ID of the model
Returns:
        dict: Model information, or None if the model_id is unknown
"""
if 'trained_models' in st.session_state and model_id in st.session_state.trained_models:
return st.session_state.trained_models[model_id]['info']
return None
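
# Example usage (illustrative sketch; returns None when the ID is unknown):
#
#     info = get_model_info("my-finetune")
#     if info is not None:
#         st.json(info)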