# Scriptr-Gemma / app.py
import re

import streamlit as st
import torch
from peft import AutoPeftModelForCausalLM
from transformers import (
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)
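
# Dependency note (an assumption, since this file does not pin versions): the app
# expects streamlit, torch, transformers and peft to be installed; transformers
# also needs accelerate when low_cpu_mem_usage=True is used below.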

# Initialize session state variables if they don't exist
if 'messages' not in st.session_state:
    st.session_state.messages = []
if 'conversation_history' not in st.session_state:
    st.session_state.conversation_history = ""

# Load the fine-tuned model and its tokenizer from the Hugging Face Hub.
def load_model():
    try:
        # Check CUDA availability
        if torch.cuda.is_available():
            device = torch.device("cuda")
            st.success(f"Using GPU: {torch.cuda.get_device_name(0)}")
        else:
            device = torch.device("cpu")
            st.warning("CUDA is not available. Using CPU.")

        # Fine-tuned model (PEFT adapter) for generating scripts
        model_name = "Sidharthan/gemma2_scripter"
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True
        )

        # Load model with appropriate device settings
        model = AutoPeftModelForCausalLM.from_pretrained(
            model_name,
            device_map=None,  # Device placement is handled manually below
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            trust_remote_code=True,
            low_cpu_mem_usage=True
        )

        # Move model to the selected device
        model = model.to(device)

        return model, tokenizer

    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        raise
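
# Optional sketch (not called by the app): if the adapter on top of Gemma-2 is a
# LoRA-style adapter (an assumption, since the adapter config is not shown here),
# PEFT can fold it into the base weights, which removes the adapter indirection
# at inference time.
def merge_adapter_for_inference(peft_model):
    # merge_and_unload() returns a plain transformers model with the adapter
    # weights merged in; it only works for mergeable adapter types such as LoRA.
    return peft_model.merge_and_unload()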

class StopWordCriteria(StoppingCriteria):
    """Stop generation once the encoded stop word appears at the end of the sequence."""

    def __init__(self, tokenizer, stop_word):
        self.stop_word_id = tokenizer.encode(stop_word, add_special_tokens=False)

    def __call__(self, input_ids, scores, **kwargs):
        # Check if the last token(s) match the stop word
        if (len(input_ids[0]) >= len(self.stop_word_id)
                and input_ids[0][-len(self.stop_word_id):].tolist() == self.stop_word_id):
            return True
        return False
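
# Illustrative usage (this mirrors generate_text below): generation halts as soon
# as the tokens for the stop word appear at the end of the sequence, e.g.
#   criteria = StoppingCriteriaList([StopWordCriteria(tokenizer, "script")])
#   output_ids = model.generate(**inputs, stopping_criteria=criteria)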

def generate_text(prompt, model, tokenizer, params, last_user_prompt=""):
    # Determine the device the model parameters live on
    device = next(model.parameters()).device

    # Tokenize and move the inputs to the same device
    inputs = tokenizer(prompt, return_tensors='pt')
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Stop generating once the model emits the closing "script" marker
    stop_word = 'script'
    stopping_criteria = StoppingCriteriaList([StopWordCriteria(tokenizer, stop_word)])

    try:
        outputs = model.generate(
            **inputs,
            max_length=params['max_length'],
            do_sample=True,
            temperature=params['temperature'],
            top_p=params['top_p'],
            top_k=params['top_k'],
            repetition_penalty=params['repetition_penalty'],
            num_return_sequences=1,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            stopping_criteria=stopping_criteria
        )

        # Move outputs back to CPU for decoding
        outputs = outputs.cpu()
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print("Response from the model:", response)

        # Clean up unwanted chat-template patterns from the decoded text
        response = re.sub(r'user\s.*?model\s', '', response, flags=re.DOTALL)
        response = re.sub(r'keywords\s.*?script\s', '', response, flags=re.DOTALL)
        response = re.sub(r'\bscript\b.*$', '', response, flags=re.IGNORECASE).strip()

        # Remove the previous prompt if it is repeated in the response
        print("Last user prompt:", last_user_prompt)
        if last_user_prompt and last_user_prompt in response:
            response = response.replace(last_user_prompt, "").strip()

        return response

    except RuntimeError as e:
        if "out of memory" in str(e):
            st.error("GPU out of memory error. Try reducing max_length or using CPU.")
            return "Error: GPU out of memory"
        else:
            st.error(f"Error during generation: {str(e)}")
            return f"Error during generation: {str(e)}"

def main():
    st.title("🤖 LLM Chat Interface")

    # Sidebar for model parameters
    st.sidebar.title("Model Parameters")
    params = {
        'max_length': st.sidebar.selectbox('Max Length', options=[64, 128, 256, 512, 1024], index=3),
        'temperature': st.sidebar.selectbox('Temperature', options=[0.2, 0.5, 0.7, 0.9, 1.0], index=2),
        'top_p': st.sidebar.selectbox('Top P', options=[0.7, 0.8, 0.9, 0.95, 1.0], index=3),
        'top_k': st.sidebar.selectbox('Top K', options=[10, 20, 50, 100], index=2),
        'repetition_penalty': st.sidebar.selectbox('Repetition Penalty', options=[1.0, 1.1, 1.2, 1.3, 1.5], index=2)
    }

    # Load model and tokenizer (cached so the model is only loaded once per session)
    @st.cache_resource
    def get_model():
        return load_model()

    model, tokenizer = get_model()

    # Chat interface
    st.markdown("### Chat Interface")

    # Display the full conversation history
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # Input area
    input_mode = st.selectbox(
        "Select Mode",
        ["Conversation", "Script Generation"],
        key="input_mode"
    )

    # Chat input
    if prompt := st.chat_input("Enter your message"):
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        # Prepare the prompt based on the selected mode
        if input_mode == "Conversation":
            # Append the new user turn to the running conversation, if any
            if st.session_state.conversation_history:
                full_prompt = f"{st.session_state.conversation_history}\n<bos><start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
            else:
                full_prompt = f"<bos><start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
        else:
            # Script generation mode: keywords in, script turn out
            full_prompt = f"<bos><start_of_turn>keywords\n{prompt}<end_of_turn>\n<start_of_turn>script\n"

        # Generate response
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                response = generate_text(full_prompt, model, tokenizer, params, last_user_prompt=prompt)
                st.markdown(response)
                st.session_state.messages.append({"role": "assistant", "content": response})

        # Update conversation history for the model (not displayed)
        if input_mode == "Conversation":
            if st.session_state.conversation_history:
                st.session_state.conversation_history = (
                    f"{st.session_state.conversation_history}"
                    f"<bos><start_of_turn>user\n{prompt}<end_of_turn>"
                    f"<start_of_turn>model\n{response}"
                )
            else:
                st.session_state.conversation_history = (
                    f"<bos><start_of_turn>user\n{prompt}<end_of_turn>"
                    f"<start_of_turn>model\n{response}"
                )

    # Clear chat button
    if st.button("Clear Chat"):
        st.session_state.messages = []
        st.session_state.conversation_history = ""
        # st.experimental_rerun() was deprecated and later removed from Streamlit;
        # st.rerun() is its replacement in current releases.
        st.rerun()

if __name__ == "__main__":
    main()
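
# To serve the app locally (assuming Streamlit is installed):
#   streamlit run app.py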