# therpist2 / app.py (author: hackergeek98)
# Last commit d3cbad9: "Update app.py" (703 bytes)
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# Hub IDs of the base checkpoint and of the fine-tuned PEFT adapter applied on top of it.
BASE_MODEL_ID = "google/gemma-3-1b-pt"
ADAPTER_ID = "hackergeek98/gemma-finetuned"
# Both the model weights and the input tensors are placed on this device.
DEVICE = "cpu"
# Upper bound on the number of tokens generated after the prompt.
MAX_NEW_TOKENS = 50

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
# Load base model; it is moved to DEVICE below together with the adapter.
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID)
# Wrap the base model with the fine-tuned PEFT adapter weights.
model = PeftModel.from_pretrained(base_model, ADAPTER_ID)
# Ensure model runs on DEVICE (CPU).
model = model.to(DEVICE)

# Test inference
input_text = "Hello, how are you?"
# Keep the whole BatchEncoding (input_ids *and* attention_mask). Passing only
# input_ids would leave generate() to infer the mask itself, and it may warn.
inputs = tokenizer(input_text, return_tensors="pt").to(DEVICE)

# Generate output. max_new_tokens budgets only the continuation, whereas
# max_length would also count the prompt tokens and shrink (or eliminate)
# the room for generated text as the prompt grows.
output = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)
# For a causal LM, output[0] is the prompt followed by the continuation;
# both are decoded and printed.
print(tokenizer.decode(output[0], skip_special_tokens=True))