Sidharthan committed
Commit 97c8b2b · Parent: 01d319b

Changed the interface and added access tokens

Files changed (1): app.py (+60, -108)
app.py CHANGED
@@ -1,46 +1,46 @@
 import streamlit as st
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, StoppingCriteria, StoppingCriteriaList
 from peft import AutoPeftModelForCausalLM
 import torch
 import re
-from transformers import StoppingCriteria, StoppingCriteriaList
+import os
 
-# Initialize session state variables if they don't exist
-if 'messages' not in st.session_state:
-    st.session_state.messages = []
-if 'conversation_history' not in st.session_state:
-    st.session_state.conversation_history = ""
+os.environ['HF_HOME'] = '/app/cache'
+hf_token = os.getenv('HF_TOKEN')
+
+class StopWordCriteria(StoppingCriteria):
+    def __init__(self, tokenizer, stop_word):
+        self.stop_word_id = tokenizer.encode(stop_word, add_special_tokens=False)
+
+    def __call__(self, input_ids, scores, **kwargs):
+        if len(input_ids[0]) >= len(self.stop_word_id) and input_ids[0][-len(self.stop_word_id):].tolist() == self.stop_word_id:
+            return True
+        return False
 
-# Load the model from huggingface.
 def load_model():
     try:
-        # Check CUDA availability
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         if torch.cuda.is_available():
-            device = torch.device("cuda")
             st.success(f"Using GPU: {torch.cuda.get_device_name(0)}")
         else:
-            device = torch.device("cpu")
-            st.warning("CUDA is not available. Using CPU.")
+            st.warning("Using CPU for inference")
 
-        # Fine-tuned model for generating scripts
         model_name = "Sidharthan/gemma2_scripter"
 
         tokenizer = AutoTokenizer.from_pretrained(
             model_name,
-            trust_remote_code=True
+            trust_remote_code=True,
+            token=hf_token
         )
 
-        # Load model with appropriate device settings
         model = AutoPeftModelForCausalLM.from_pretrained(
             model_name,
-            device_map=None,  # We'll handle device placement manually
+            device_map=None,
             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
             trust_remote_code=True,
-            low_cpu_mem_usage=True
-        )
-
-        # Move model to device
-        model = model.to(device)
+            low_cpu_mem_usage=True,
+            cache_dir='/app/cache'
+        ).to(device)
 
         return model, tokenizer
 
@@ -48,22 +48,13 @@ def load_model():
         st.error(f"Error loading model: {str(e)}")
         raise e
 
-
-class StopWordCriteria(StoppingCriteria):
-    def __init__(self, tokenizer, stop_word):
-        self.stop_word_id = tokenizer.encode(stop_word, add_special_tokens=False)
-
-    def __call__(self, input_ids, scores, **kwargs):
-        # Check if the last token(s) match the stop word
-        if len(input_ids[0]) >= len(self.stop_word_id) and input_ids[0][-len(self.stop_word_id):].tolist() == self.stop_word_id:
-            return True
-        return False
-
-def generate_text(prompt, model, tokenizer, params, last_user_prompt=""):
-    # Determine the device
+def generate_script(tags, model, tokenizer, params):
     device = next(model.parameters()).device
 
-    # Tokenize and move to the correct device
+    # Create prompt with tags
+    prompt = f"<bos><start_of_turn>keywords\n{tags}<end_of_turn>\n<start_of_turn>script\n"
+
+    # Tokenize and move to device
     inputs = tokenizer(prompt, return_tensors='pt')
     inputs = {k: v.to(device) for k, v in inputs.items()}
 
@@ -85,22 +76,12 @@ def generate_text(prompt, model, tokenizer, params, last_user_prompt=""):
             stopping_criteria=stopping_criteria
         )
 
-        # Move outputs back to CPU for decoding
-        outputs = outputs.cpu()
         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        print("Response from the model:", response)
 
-        # Clean up unwanted patterns
-        response = re.sub(r'user\s.*?model\s', '', response, flags=re.DOTALL)
+        # Clean up response
        response = re.sub(r'keywords\s.*?script\s', '', response, flags=re.DOTALL)
         response = re.sub(r'\bscript\b.*$', '', response, flags=re.IGNORECASE).strip()
 
-        # Remove previous prompt if repeated in response
-        print("Last user prompt:", last_user_prompt)
-        if last_user_prompt and last_user_prompt in response:
-
-            response = response.replace(last_user_prompt, "").strip()
-
         return response
 
     except RuntimeError as e:
@@ -112,16 +93,16 @@ def generate_text(prompt, model, tokenizer, params, last_user_prompt=""):
         return f"Error during generation: {str(e)}"
 
 def main():
-    st.title("🤖 LLM Chat Interface")
+    st.title("🎥 YouTube Script Generator")
 
     # Sidebar for model parameters
-    st.sidebar.title("Model Parameters")
+    st.sidebar.title("Generation Parameters")
     params = {
-        'max_length': st.sidebar.selectbox('Max Length', options=[64, 128, 256, 512, 1024], index=3),
-        'temperature': st.sidebar.selectbox('Temperature', options=[0.2, 0.5, 0.7, 0.9, 1.0], index=2),
-        'top_p': st.sidebar.selectbox('Top P', options=[0.7, 0.8, 0.9, 0.95, 1.0], index=3),
-        'top_k': st.sidebar.selectbox('Top K', options=[10, 20, 50, 100], index=2),
-        'repetition_penalty': st.sidebar.selectbox('Repetition Penalty', options=[1.0, 1.1, 1.2, 1.3, 1.5], index=2)
+        'max_length': st.sidebar.slider('Max Length', 64, 1024, 512),
+        'temperature': st.sidebar.slider('Temperature', 0.1, 1.0, 0.7),
+        'top_p': st.sidebar.slider('Top P', 0.1, 1.0, 0.95),
+        'top_k': st.sidebar.slider('Top K', 1, 100, 50),
+        'repetition_penalty': st.sidebar.slider('Repetition Penalty', 1.0, 2.0, 1.2)
     }
 
     # Load model and tokenizer
@@ -131,65 +112,36 @@ def main():
 
     model, tokenizer = get_model()
 
-    # Chat interface
-    st.markdown("### Chat Interface")
+    # Tag input section
+    st.markdown("### Add Tags")
+    st.markdown("Enter tags separated by commas to generate a YouTube script")
+
+    # Create columns for tag input and generate button
+    col1, col2 = st.columns([3, 1])
 
-    # Display the full conversation history
-    for message in st.session_state.messages:
-        with st.chat_message(message["role"]):
-            st.markdown(message["content"])
+    with col1:
+        tags = st.text_input("Enter tags", placeholder="tech, AI, future, innovations...")
 
-    # Input area
-    input_mode = st.selectbox(
-        "Select Mode",
-        ["Conversation", "Script Generation"],
-        key="input_mode"
-    )
+    with col2:
+        generate_button = st.button("Generate Script", type="primary")
 
-    # Chat input
-    if prompt := st.chat_input("Enter your message"):
-        # Add user message to chat history
-        st.session_state.messages.append({"role": "user", "content": prompt})
-        with st.chat_message("user"):
-            st.markdown(prompt)
+    # Generated script section
+    if generate_button and tags:
+        st.markdown("### Generated Script")
+        with st.spinner("Generating script..."):
+            script = generate_script(tags, model, tokenizer, params)
+            st.text_area("Your script:", value=script, height=400)
 
-        # Prepare prompt based on selected mode
-        if input_mode == "Conversation":
-            # Add new user input to conversation history
-            if st.session_state.conversation_history:
-                full_prompt = f"{st.session_state.conversation_history}\n<bos><start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
-            else:
-                full_prompt = f"<bos><start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
-        else:
-            # Script generation mode
-            full_prompt = f"<bos><start_of_turn>keywords\n{prompt}<end_of_turn>\n<start_of_turn>script\n"
-
-        # Generate response
-        with st.chat_message("assistant"):
-            with st.spinner("Thinking..."):
-                response = generate_text(full_prompt, model, tokenizer, params, last_user_prompt=prompt)
-                st.markdown(response)
-                st.session_state.messages.append({"role": "assistant", "content": response})
-
-        # Update conversation history for the model (not displayed)
-        if input_mode == "Conversation":
-            if st.session_state.conversation_history:
-                st.session_state.conversation_history = (
-                    f"{st.session_state.conversation_history}"
-                    f"<bos><start_of_turn>user\n{prompt}<end_of_turn>"
-                    f"<start_of_turn>model\n{response}"
-                )
-            else:
-                st.session_state.conversation_history = (
-                    f"<bos><start_of_turn>user\n{prompt}<end_of_turn>"
-                    f"<start_of_turn>model\n{response}"
-                )
+            # Add download button
+            st.download_button(
+                label="Download Script",
+                data=script,
+                file_name="youtube_script.txt",
+                mime="text/plain"
+            )
 
-    # Clear chat button
-    if st.button("Clear Chat"):
-        st.session_state.messages = []
-        st.session_state.conversation_history = ""
-        st.experimental_rerun()
+    elif generate_button and not tags:
+        st.warning("Please enter some tags first!")
 
 if __name__ == "__main__":
-    main()
+    main()
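Both the old generate_text and the new generate_script pass a stopping_criteria argument to model.generate, but its construction sits in unchanged context that the diff does not display. Below is a minimal sketch of how the StopWordCriteria class (moved to the top of the file in this commit) would typically be wired in; the stop word "<end_of_turn>" is an assumption based on the Gemma prompt template used above, and the elided lines in app.py may differ.

# Hedged sketch of the elided generation call; not a verbatim excerpt.
# Assumes model, tokenizer, inputs, and params are defined as in the
# hunks above.
from transformers import StoppingCriteriaList

# Stop as soon as the model emits the Gemma end-of-turn marker
# (assumed stop word).
stopping_criteria = StoppingCriteriaList([
    StopWordCriteria(tokenizer, "<end_of_turn>")
])

outputs = model.generate(
    **inputs,
    max_length=params['max_length'],
    temperature=params['temperature'],
    top_p=params['top_p'],
    top_k=params['top_k'],
    repetition_penalty=params['repetition_penalty'],
    do_sample=True,
    stopping_criteria=stopping_criteria
)

Note that StopWordCriteria.__call__ compares token ids, so it only fires when the generated tail exactly matches the id sequence produced by tokenizer.encode(stop_word); if the model reaches the same string through a different tokenization, the criterion would not trigger.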
 
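The "access tokens" part of the commit message appears as hf_token = os.getenv('HF_TOKEN'), which is passed to AutoTokenizer.from_pretrained; on a Hugging Face Space the variable would come from a repository secret. A minimal sketch of reproducing that environment for a local run follows; the token value is a placeholder and the launcher is not part of this commit.

# Hypothetical local launcher, not part of app.py: provide the same
# environment the Space supplies, then start Streamlit.
import os
import subprocess

os.environ["HF_TOKEN"] = "hf_xxx"     # placeholder; use a real access token
os.environ["HF_HOME"] = "/app/cache"  # must exist and be writable locally

subprocess.run(["streamlit", "run", "app.py"], check=True)

Two design notes. First, app.py sets os.environ['HF_HOME'] after transformers has been imported, and huggingface_hub generally resolves HF_HOME at import time, so that assignment may not move the default cache; that is presumably why cache_dir='/app/cache' is also passed explicitly to from_pretrained. Second, only the tokenizer load passes token=hf_token; if the base or adapter weights are gated, the AutoPeftModelForCausalLM.from_pretrained call would need the same argument.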