Sidharthan committed on
Commit bfb6e0a · 1 Parent(s): 62eb74f

Added application file

Files changed (1)
  1. app.py +195 -0
app.py ADDED
@@ -0,0 +1,195 @@
+ import streamlit as st
+ from transformers import AutoTokenizer
+ from peft import AutoPeftModelForCausalLM
+ import torch
+ import re
+ from transformers import StoppingCriteria, StoppingCriteriaList
+
+ # Initialize session state variables if they don't exist
+ if 'messages' not in st.session_state:
+     st.session_state.messages = []
+ if 'conversation_history' not in st.session_state:
+     st.session_state.conversation_history = ""
+
+ # Load the model from huggingface.
+ def load_model():
+     try:
+         # Check CUDA availability
+         if torch.cuda.is_available():
+             device = torch.device("cuda")
+             st.success(f"Using GPU: {torch.cuda.get_device_name(0)}")
+         else:
+             device = torch.device("cpu")
+             st.warning("CUDA is not available. Using CPU.")
+
+         # Fine-tuned model for generating scripts
+         model_name = "Sidharthan/gemma2_scripter"
+
+         tokenizer = AutoTokenizer.from_pretrained(
+             model_name,
+             trust_remote_code=True
+         )
+
+         # Load model with appropriate device settings
+         model = AutoPeftModelForCausalLM.from_pretrained(
+             model_name,
+             device_map=None,  # We'll handle device placement manually
+             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+             trust_remote_code=True,
+             low_cpu_mem_usage=True
+         )
+
+         # Move model to device
+         model = model.to(device)
+
+         return model, tokenizer
+
+     except Exception as e:
+         st.error(f"Error loading model: {str(e)}")
+         raise e
+
+
+ class StopWordCriteria(StoppingCriteria):
+     def __init__(self, tokenizer, stop_word):
+         self.stop_word_id = tokenizer.encode(stop_word, add_special_tokens=False)
+
+     def __call__(self, input_ids, scores, **kwargs):
+         # Check if the last token(s) match the stop word
+         if len(input_ids[0]) >= len(self.stop_word_id) and input_ids[0][-len(self.stop_word_id):].tolist() == self.stop_word_id:
+             return True
+         return False
+
+ def generate_text(prompt, model, tokenizer, params, last_user_prompt=""):
+     # Determine the device
+     device = next(model.parameters()).device
+
+     # Tokenize and move to the correct device
+     inputs = tokenizer(prompt, return_tensors='pt')
+     inputs = {k: v.to(device) for k, v in inputs.items()}
+
+     stop_word = 'script'
+     stopping_criteria = StoppingCriteriaList([StopWordCriteria(tokenizer, stop_word)])
+
+     try:
+         outputs = model.generate(
+             **inputs,
+             max_length=params['max_length'],
+             do_sample=True,
+             temperature=params['temperature'],
+             top_p=params['top_p'],
+             top_k=params['top_k'],
+             repetition_penalty=params['repetition_penalty'],
+             num_return_sequences=1,
+             pad_token_id=tokenizer.pad_token_id,
+             eos_token_id=tokenizer.eos_token_id,
+             stopping_criteria=stopping_criteria
+         )
+
+         # Move outputs back to CPU for decoding
+         outputs = outputs.cpu()
+         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         print("Response from the model:", response)
+
+         # Clean up unwanted patterns
+         response = re.sub(r'user\s.*?model\s', '', response, flags=re.DOTALL)
+         response = re.sub(r'keywords\s.*?script\s', '', response, flags=re.DOTALL)
+         response = re.sub(r'\bscript\b.*$', '', response, flags=re.IGNORECASE).strip()
+
+         # Remove previous prompt if repeated in response
+         print("Last user prompt:", last_user_prompt)
+         if last_user_prompt and last_user_prompt in response:
+
+             response = response.replace(last_user_prompt, "").strip()
+
+         return response
+
+     except RuntimeError as e:
+         if "out of memory" in str(e):
+             st.error("GPU out of memory error. Try reducing max_length or using CPU.")
+             return "Error: GPU out of memory"
+         else:
+             st.error(f"Error during generation: {str(e)}")
+             return f"Error during generation: {str(e)}"
+
+ def main():
+     st.title("🤖 LLM Chat Interface")
+
+     # Sidebar for model parameters
+     st.sidebar.title("Model Parameters")
+     params = {
+         'max_length': st.sidebar.selectbox('Max Length', options=[64, 128, 256, 512, 1024], index=3),
+         'temperature': st.sidebar.selectbox('Temperature', options=[0.2, 0.5, 0.7, 0.9, 1.0], index=2),
+         'top_p': st.sidebar.selectbox('Top P', options=[0.7, 0.8, 0.9, 0.95, 1.0], index=3),
+         'top_k': st.sidebar.selectbox('Top K', options=[10, 20, 50, 100], index=2),
+         'repetition_penalty': st.sidebar.selectbox('Repetition Penalty', options=[1.0, 1.1, 1.2, 1.3, 1.5], index=2)
+     }
+
+     # Load model and tokenizer
+     @st.cache_resource
+     def get_model():
+         return load_model()
+
+     model, tokenizer = get_model()
+
+     # Chat interface
+     st.markdown("### Chat Interface")
+
+     # Display the full conversation history
+     for message in st.session_state.messages:
+         with st.chat_message(message["role"]):
+             st.markdown(message["content"])
+
+     # Input area
+     input_mode = st.selectbox(
+         "Select Mode",
+         ["Conversation", "Script Generation"],
+         key="input_mode"
+     )
+
+     # Chat input
+     if prompt := st.chat_input("Enter your message"):
+         # Add user message to chat history
+         st.session_state.messages.append({"role": "user", "content": prompt})
+         with st.chat_message("user"):
+             st.markdown(prompt)
+
+         # Prepare prompt based on selected mode
+         if input_mode == "Conversation":
+             # Add new user input to conversation history
+             if st.session_state.conversation_history:
+                 full_prompt = f"{st.session_state.conversation_history}\n<bos><start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
+             else:
+                 full_prompt = f"<bos><start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
+         else:
+             # Script generation mode
+             full_prompt = f"<bos><start_of_turn>keywords\n{prompt}<end_of_turn>\n<start_of_turn>script\n"
+
+         # Generate response
+         with st.chat_message("assistant"):
+             with st.spinner("Thinking..."):
+                 response = generate_text(full_prompt, model, tokenizer, params, last_user_prompt=prompt)
+                 st.markdown(response)
+                 st.session_state.messages.append({"role": "assistant", "content": response})
+
+         # Update conversation history for the model (not displayed)
+         if input_mode == "Conversation":
+             if st.session_state.conversation_history:
+                 st.session_state.conversation_history = (
+                     f"{st.session_state.conversation_history}"
+                     f"<bos><start_of_turn>user\n{prompt}<end_of_turn>"
+                     f"<start_of_turn>model\n{response}"
+                 )
+             else:
+                 st.session_state.conversation_history = (
+                     f"<bos><start_of_turn>user\n{prompt}<end_of_turn>"
+                     f"<start_of_turn>model\n{response}"
+                 )
+
+     # Clear chat button
+     if st.button("Clear Chat"):
+         st.session_state.messages = []
+         st.session_state.conversation_history = ""
+         st.experimental_rerun()
+
+ if __name__ == "__main__":
+     main()
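
This commit adds only app.py, with no requirements file, so the dependencies are inferred from the imports above (streamlit, torch, transformers, peft); once they are installed, the app can be launched locally with `streamlit run app.py`.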