import streamlit as st
import torch
import re
from transformers import AutoTokenizer, StoppingCriteria, StoppingCriteriaList
from peft import AutoPeftModelForCausalLM

# Initialize session state variables if they don't exist
if 'messages' not in st.session_state:
    st.session_state.messages = []
if 'conversation_history' not in st.session_state:
    st.session_state.conversation_history = ""


# Load the fine-tuned model and tokenizer from the Hugging Face Hub.
def load_model():
    try:
        # Check CUDA availability
        if torch.cuda.is_available():
            device = torch.device("cuda")
            st.success(f"Using GPU: {torch.cuda.get_device_name(0)}")
        else:
            device = torch.device("cpu")
            st.warning("CUDA is not available. Using CPU.")

        # Fine-tuned model for generating scripts
        model_name = "Sidharthan/gemma2_scripter"

        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True
        )

        # Load model with appropriate device settings
        model = AutoPeftModelForCausalLM.from_pretrained(
            model_name,
            device_map=None,  # We'll handle device placement manually
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            trust_remote_code=True,
            low_cpu_mem_usage=True
        )

        # Move model to device
        model = model.to(device)

        return model, tokenizer

    except Exception as e:
        st.error(f"Error loading model: {str(e)}")
        raise e


class StopWordCriteria(StoppingCriteria):
    def __init__(self, tokenizer, stop_word):
        self.stop_word_id = tokenizer.encode(stop_word, add_special_tokens=False)

    def __call__(self, input_ids, scores, **kwargs):
        # Check if the last generated token(s) match the stop word
        if len(input_ids[0]) >= len(self.stop_word_id) and \
                input_ids[0][-len(self.stop_word_id):].tolist() == self.stop_word_id:
            return True
        return False


def generate_text(prompt, model, tokenizer, params, last_user_prompt=""):
    # Determine the device the model lives on
    device = next(model.parameters()).device

    # Tokenize and move inputs to the correct device
    inputs = tokenizer(prompt, return_tensors='pt')
    inputs = {k: v.to(device) for k, v in inputs.items()}

    stop_word = 'script'
    stopping_criteria = StoppingCriteriaList([StopWordCriteria(tokenizer, stop_word)])

    try:
        outputs = model.generate(
            **inputs,
            max_length=params['max_length'],
            do_sample=True,
            temperature=params['temperature'],
            top_p=params['top_p'],
            top_k=params['top_k'],
            repetition_penalty=params['repetition_penalty'],
            num_return_sequences=1,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            stopping_criteria=stopping_criteria
        )

        # Move outputs back to CPU for decoding
        outputs = outputs.cpu()
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print("Response from the model:", response)

        # Clean up unwanted patterns from the prompt template
        response = re.sub(r'user\s.*?model\s', '', response, flags=re.DOTALL)
        response = re.sub(r'keywords\s.*?script\s', '', response, flags=re.DOTALL)
        response = re.sub(r'\bscript\b.*$', '', response, flags=re.IGNORECASE).strip()

        # Remove the previous prompt if it is repeated in the response
        print("Last user prompt:", last_user_prompt)
        if last_user_prompt and last_user_prompt in response:
            response = response.replace(last_user_prompt, "").strip()

        return response

    except RuntimeError as e:
        if "out of memory" in str(e):
            st.error("GPU out of memory error. Try reducing max_length or using CPU.")
            return "Error: GPU out of memory"
        else:
            st.error(f"Error during generation: {str(e)}")
            return f"Error during generation: {str(e)}"


def main():
    st.title("🤖 LLM Chat Interface")

    # Sidebar for model parameters
    st.sidebar.title("Model Parameters")
    params = {
        'max_length': st.sidebar.selectbox('Max Length', options=[64, 128, 256, 512, 1024], index=3),
        'temperature': st.sidebar.selectbox('Temperature', options=[0.2, 0.5, 0.7, 0.9, 1.0], index=2),
        'top_p': st.sidebar.selectbox('Top P', options=[0.7, 0.8, 0.9, 0.95, 1.0], index=3),
        'top_k': st.sidebar.selectbox('Top K', options=[10, 20, 50, 100], index=2),
        'repetition_penalty': st.sidebar.selectbox('Repetition Penalty', options=[1.0, 1.1, 1.2, 1.3, 1.5], index=2)
    }

    # Load model and tokenizer (cached across reruns)
    @st.cache_resource
    def get_model():
        return load_model()

    model, tokenizer = get_model()

    # Chat interface
    st.markdown("### Chat Interface")

    # Display the full conversation history
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    # Input area
    input_mode = st.selectbox(
        "Select Mode",
        ["Conversation", "Script Generation"],
        key="input_mode"
    )

    # Chat input
    if prompt := st.chat_input("Enter your message"):
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        # Prepare prompt based on selected mode
        if input_mode == "Conversation":
            # Append new user input to the running conversation history
            if st.session_state.conversation_history:
                full_prompt = f"{st.session_state.conversation_history}\nuser\n{prompt}\nmodel\n"
            else:
                full_prompt = f"user\n{prompt}\nmodel\n"
        else:
            # Script generation mode
            full_prompt = f"keywords\n{prompt}\nscript\n"

        # Generate response
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                response = generate_text(full_prompt, model, tokenizer, params, last_user_prompt=prompt)
                st.markdown(response)
                st.session_state.messages.append({"role": "assistant", "content": response})

        # Update conversation history for the model (not displayed)
        if input_mode == "Conversation":
            if st.session_state.conversation_history:
                st.session_state.conversation_history = (
                    f"{st.session_state.conversation_history}"
                    f"user\n{prompt}"
                    f"model\n{response}"
                )
            else:
                st.session_state.conversation_history = (
                    f"user\n{prompt}"
                    f"model\n{response}"
                )

    # Clear chat button
    if st.button("Clear Chat"):
        st.session_state.messages = []
        st.session_state.conversation_history = ""
        st.rerun()  # use st.experimental_rerun() on Streamlit versions older than 1.27


if __name__ == "__main__":
    main()
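
# Usage sketch (assumes this script is saved as app.py; adjust the filename and
# package versions to your environment):
#   pip install streamlit torch transformers peft accelerate
#   streamlit run app.py
# The first run downloads the "Sidharthan/gemma2_scripter" adapter and its base
# model from the Hugging Face Hub, so expect a sizeable initial download before
# the chat interface becomes responsive.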