# ai/model.py
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "bigcode/starcoderbase-3b"
HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
device = "cpu"

# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)

# Ensure the tokenizer has a pad token; fall back to the EOS token
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    token=HF_TOKEN,
    torch_dtype=torch.float32,  # float32 for CPU compatibility
    trust_remote_code=True,
).to(device)


def generate_code(prompt: str, max_tokens: int = 256) -> str:
    formatted_prompt = f"{prompt}\n### Code:\n"  # Signal to the model that code should follow
    inputs = tokenizer(
        formatted_prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,  # Cap the prompt length to avoid overlong inputs
    ).to(device)
    output = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        pad_token_id=tokenizer.pad_token_id,
        do_sample=True,   # Sample instead of greedy decoding
        top_p=0.95,       # Nucleus sampling
        temperature=0.7,  # Moderate randomness
    )
    generated_code = tokenizer.decode(output[0], skip_special_tokens=True)
    # Strip the echoed prompt from the start of the output
    if generated_code.startswith(formatted_prompt):
        generated_code = generated_code[len(formatted_prompt):]
    return generated_code.strip()
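

# Minimal usage sketch (not part of the original file): the prompt below is a
# hypothetical example. Running this performs a real generation, which first
# downloads the ~3B-parameter model and is slow on CPU.
if __name__ == "__main__":
    sample_prompt = "Write a Python function that reverses a string."
    print(generate_code(sample_prompt, max_tokens=128))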