import streamlit as st
from transformers import AutoTokenizer
from peft import AutoPeftModelForCausalLM
import torch
import re
from transformers import StoppingCriteria, StoppingCriteriaList

# Initialize session state variables if they don't exist
if 'messages' not in st.session_state:
    st.session_state.messages = []
if 'conversation_history' not in st.session_state:
    st.session_state.conversation_history = ""
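
# st.session_state persists across Streamlit's script reruns, so the chat transcript
# (messages) and the raw prompt history fed to the model (conversation_history)
# survive each interaction instead of being rebuilt from scratch.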

# Load the model and tokenizer from the Hugging Face Hub.
def load_model():
    try:
        # Check CUDA availability
        if torch.cuda.is_available():
            device = torch.device("cuda")
            st.success(f"Using GPU: {torch.cuda.get_device_name(0)}")
        else:
            device = torch.device("cpu")
            st.warning("CUDA is not available. Using CPU.")

        # Fine-tuned model for generating scripts
        model_name = "Sidharthan/gemma2_scripter"

        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True
        )
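
        # AutoPeftModelForCausalLM reads the adapter config stored with this checkpoint,
        # loads the underlying base model, and attaches the fine-tuned PEFT/LoRA adapter
        # weights on top of it, so a single repo id restores the full model.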

        # Load model with appropriate device settings
        model = AutoPeftModelForCausalLM.from_pretrained(
            model_name,
            device_map=None,  # We'll handle device placement manually
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            trust_remote_code=True,
            low_cpu_mem_usage=True
        )

        # Move model to device
        model = model.to(device)

        return model, tokenizer
    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        raise e

class StopWordCriteria(StoppingCriteria):
    def __init__(self, tokenizer, stop_word):
        self.stop_word_id = tokenizer.encode(stop_word, add_special_tokens=False)

    def __call__(self, input_ids, scores, **kwargs):
        # Check if the last token(s) match the stop word
        if len(input_ids[0]) >= len(self.stop_word_id) and input_ids[0][-len(self.stop_word_id):].tolist() == self.stop_word_id:
            return True
        return False
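
# The criteria fires as soon as the most recently generated tokens decode to the stop
# word, so generation halts right after the model emits "script"; the stop word itself
# is still present in the raw output and is stripped by the regex cleanup in
# generate_text below.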

def generate_text(prompt, model, tokenizer, params, last_user_prompt=""):
    # Determine the device
    device = next(model.parameters()).device

    # Tokenize and move to the correct device
    inputs = tokenizer(prompt, return_tensors='pt')
    inputs = {k: v.to(device) for k, v in inputs.items()}

    stop_word = 'script'
    stopping_criteria = StoppingCriteriaList([StopWordCriteria(tokenizer, stop_word)])

    try:
        outputs = model.generate(
            **inputs,
            max_length=params['max_length'],
            do_sample=True,
            temperature=params['temperature'],
            top_p=params['top_p'],
            top_k=params['top_k'],
            repetition_penalty=params['repetition_penalty'],
            num_return_sequences=1,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            stopping_criteria=stopping_criteria
        )

        # Move outputs back to CPU for decoding
        outputs = outputs.cpu()
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print("Response from the model:", response)

        # Clean up unwanted patterns
        response = re.sub(r'user\s.*?model\s', '', response, flags=re.DOTALL)
        response = re.sub(r'keywords\s.*?script\s', '', response, flags=re.DOTALL)
        response = re.sub(r'\bscript\b.*$', '', response, flags=re.IGNORECASE).strip()

        # Remove previous prompt if repeated in response
        print("Last user prompt:", last_user_prompt)
        if last_user_prompt and last_user_prompt in response:
            response = response.replace(last_user_prompt, "").strip()

        return response
    except RuntimeError as e:
        if "out of memory" in str(e):
            st.error("GPU out of memory error. Try reducing max_length or using CPU.")
            return "Error: GPU out of memory"
        else:
            st.error(f"Error during generation: {str(e)}")
            return f"Error during generation: {str(e)}"

def main():
    st.title("🤖 LLM Chat Interface")

    # Sidebar for model parameters
    st.sidebar.title("Model Parameters")
    params = {
        'max_length': st.sidebar.selectbox('Max Length', options=[64, 128, 256, 512, 1024], index=3),
        'temperature': st.sidebar.selectbox('Temperature', options=[0.2, 0.5, 0.7, 0.9, 1.0], index=2),
        'top_p': st.sidebar.selectbox('Top P', options=[0.7, 0.8, 0.9, 0.95, 1.0], index=3),
        'top_k': st.sidebar.selectbox('Top K', options=[10, 20, 50, 100], index=2),
        'repetition_penalty': st.sidebar.selectbox('Repetition Penalty', options=[1.0, 1.1, 1.2, 1.3, 1.5], index=2)
    }
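
    # Note: 'max_length' is passed to model.generate as the total sequence length
    # (prompt tokens plus newly generated tokens), not just the length of the reply.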

    # Load model and tokenizer once and cache the result so the weights are not
    # reloaded on every Streamlit rerun (assumes Streamlit >= 1.18 for st.cache_resource).
    @st.cache_resource
    def get_model():
        return load_model()

    model, tokenizer = get_model()

    # Chat interface
    st.markdown("### Chat Interface")

    # Display the full conversation history
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # Input area
    input_mode = st.selectbox(
        "Select Mode",
        ["Conversation", "Script Generation"],
        key="input_mode"
    )

    # Chat input
    if prompt := st.chat_input("Enter your message"):
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        # Prepare prompt based on selected mode
        if input_mode == "Conversation":
            # Add new user input to conversation history
            if st.session_state.conversation_history:
                full_prompt = f"{st.session_state.conversation_history}\n<bos><start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
            else:
                full_prompt = f"<bos><start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
        else:
            # Script generation mode
            full_prompt = f"<bos><start_of_turn>keywords\n{prompt}<end_of_turn>\n<start_of_turn>script\n"

        # Generate response
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                response = generate_text(full_prompt, model, tokenizer, params, last_user_prompt=prompt)
                st.markdown(response)
                st.session_state.messages.append({"role": "assistant", "content": response})

        # Update conversation history for the model (not displayed)
        if input_mode == "Conversation":
            if st.session_state.conversation_history:
                st.session_state.conversation_history = (
                    f"{st.session_state.conversation_history}"
                    f"<bos><start_of_turn>user\n{prompt}<end_of_turn>"
                    f"<start_of_turn>model\n{response}"
                )
            else:
                st.session_state.conversation_history = (
                    f"<bos><start_of_turn>user\n{prompt}<end_of_turn>"
                    f"<start_of_turn>model\n{response}"
                )
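
        # Caveat: the history string grows with every turn and is never truncated, so a
        # long conversation can eventually exceed the selected max_length for generation.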

    # Clear chat button
    if st.button("Clear Chat"):
        st.session_state.messages = []
        st.session_state.conversation_history = ""
        # st.experimental_rerun is deprecated on newer Streamlit releases, where
        # st.rerun() is the replacement; use whichever matches the installed version.
        st.experimental_rerun()

if __name__ == "__main__":
    main()