# Hugging Face Space: Streamlit demo for the m-a-p/neo_scalinglaw_250M model.
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer

# Title of the Streamlit app
st.title("Neo Scalinglaw 250M Model")

# Text input for the user's prompt
user_input = st.text_input("Enter your prompt:")
@st.cache_resource
def load_model():
    """Load and cache the neo_scalinglaw_250M tokenizer and model.

    Streamlit reruns the whole script on every widget interaction;
    ``st.cache_resource`` ensures the model is loaded only once per
    server process instead of on every rerun.

    Returns:
        tuple: (tokenizer, model) — the slow tokenizer (``use_fast=False``)
        and the causal-LM model in eval mode, placed via ``device_map="auto"``.
    """
    model_path = 'm-a-p/neo_scalinglaw_250M'
    # trust_remote_code is required because this repo ships custom model code.
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", torch_dtype='auto').eval()
    return tokenizer, model
tokenizer, model = load_model()

# Generate text when the user inputs a prompt and presses the button
if st.button("Generate"):
    if user_input:
        with st.spinner("Generating response..."):
            # NOTE: removed add_generation_prompt=True — that kwarg belongs to
            # tokenizer.apply_chat_template(), not tokenizer.__call__(), where
            # it is unrecognized and triggers a warning (or error) at runtime.
            inputs = tokenizer(user_input, return_tensors='pt').to(model.device)
            output_ids = model.generate(**inputs, max_new_tokens=20)
            # Decode the full sequence (prompt + continuation), dropping
            # special tokens such as BOS/EOS.
            response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
            st.success("Generated response:")
            st.write(response)
    else:
        st.error("Please enter a prompt.")