import os
import re

import streamlit as st
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, StoppingCriteria, StoppingCriteriaList

# Set cache directory and get token
os.environ['HF_HOME'] = '/app/cache'
hf_token = os.getenv('HF_TOKEN')

class StopWordCriteria(StoppingCriteria):
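    """Stop generation once the encoded stop word appears at the end of the generated sequence."""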
    def __init__(self, tokenizer, stop_word):
        self.stop_word_id = tokenizer.encode(stop_word, add_special_tokens=False)

    def __call__(self, input_ids, scores, **kwargs):
        if len(input_ids[0]) >= len(self.stop_word_id) and input_ids[0][-len(self.stop_word_id):].tolist() == self.stop_word_id:
            return True
        return False

def load_model():
    try:
        # Ensure cache directory exists
        cache_dir = '/app/cache'
        os.makedirs(cache_dir, exist_ok=True)
        
        # Check for HF token
        if not hf_token:
            st.warning("HuggingFace token not found. Some models may not be accessible.")
        
        # Check CUDA availability
        if torch.cuda.is_available():
            device = torch.device("cuda")
            st.success(f"Using GPU: {torch.cuda.get_device_name(0)}")
        else:
            device = torch.device("cpu")
            st.warning("CUDA is not available. Using CPU.")

        # PEFT fine-tuned adapter that generates YouTube scripts from keyword tags
        model_name = "Sidharthan/gemma2_scripter"
        
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=True,
                token=hf_token,
                cache_dir=cache_dir
            )
        except Exception as e:
            st.error(f"Error loading tokenizer: {str(e)}")
            if "401" in str(e):
                st.error("Authentication error. Please check your HuggingFace token.")
            raise e
        
        try:
            # Load model with appropriate device settings
            model = AutoPeftModelForCausalLM.from_pretrained(
                model_name,
                device_map=None,  # We'll handle device placement manually
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
                token=hf_token,
                cache_dir=cache_dir
            )
            
            # Move model to device
            model = model.to(device)
            
            return model, tokenizer

        except Exception as e:
            st.error(f"Error loading model: {str(e)}")
            if "401" in str(e):
                st.error("Authentication error. Please check your HuggingFace token.")
            elif "disk space" in str(e).lower():
                st.error("Insufficient disk space in cache directory.")
            raise e

    except Exception as e:
        st.error(f"General error during model loading: {str(e)}")
        raise e

def generate_script(tags, model, tokenizer, params):
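    # Run generation on whichever device the model was placed on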
    device = next(model.parameters()).device
    
    # Build the prompt in the keywords/script turn format the model was fine-tuned on
    prompt = f"<bos><start_of_turn>keywords\n{tags}<end_of_turn>\n<start_of_turn>script\n"
    
    # Tokenize and move to device
    inputs = tokenizer(prompt, return_tensors='pt')
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
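    # Stop generating once the model emits the next 'script' turn marker from the prompt template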
    stop_word = 'script'
    stopping_criteria = StoppingCriteriaList([StopWordCriteria(tokenizer, stop_word)])
    
    try:
        outputs = model.generate(
            **inputs,
            max_length=params['max_length'],
            do_sample=True,
            temperature=params['temperature'],
            top_p=params['top_p'],
            top_k=params['top_k'],
            repetition_penalty=params['repetition_penalty'],
            num_return_sequences=1,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            stopping_criteria=stopping_criteria
        )
        
        # Move outputs back to CPU for decoding
        outputs = outputs.cpu()
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        # Clean up: strip the echoed keywords/script prompt and any trailing 'script' marker
        response = re.sub(r'keywords\s.*?script\s', '', response, flags=re.DOTALL)
        response = re.sub(r'\bscript\b.*$', '', response, flags=re.IGNORECASE).strip()
        
        return response
        
    except RuntimeError as e:
        if "out of memory" in str(e):
            st.error("GPU out of memory error. Try reducing max_length or using CPU.")
            return "Error: GPU out of memory"
        else:
            st.error(f"Error during generation: {str(e)}")
            return f"Error during generation: {str(e)}"

def main():
    st.title("🎥 YouTube Script Generator")
    
    # Sidebar for model parameters
    st.sidebar.title("Generation Parameters")
    params = {
        'max_length': st.sidebar.slider('Max Length', 64, 1024, 512),
        'temperature': st.sidebar.slider('Temperature', 0.1, 1.0, 0.7),
        'top_p': st.sidebar.slider('Top P', 0.1, 1.0, 0.95),
        'top_k': st.sidebar.slider('Top K', 1, 100, 50),
        'repetition_penalty': st.sidebar.slider('Repetition Penalty', 1.0, 2.0, 1.2)
    }
    
    # Load model and tokenizer once; st.cache_resource reuses them across Streamlit reruns
    @st.cache_resource
    def get_model():
        return load_model()
    
    try:
        model, tokenizer = get_model()
        
        # Tag input section
        st.markdown("### Add Tags")
        st.markdown("Enter tags separated by commas to generate a YouTube script")
        
        # Create columns for tag input and generate button
        col1, col2 = st.columns([3, 1])
        
        with col1:
            tags = st.text_input("Enter tags", placeholder="tech, AI, future, innovations...")
        
        with col2:
            generate_button = st.button("Generate Script", type="primary")
        
        # Generated script section
        if generate_button and tags:
            st.markdown("### Generated Script")
            with st.spinner("Generating script..."):
                script = generate_script(tags, model, tokenizer, params)
                st.text_area("Your script:", value=script, height=400)
                
                # Add download button
                st.download_button(
                    label="Download Script",
                    data=script,
                    file_name="youtube_script.txt",
                    mime="text/plain"
                )
        
        elif generate_button and not tags:
            st.warning("Please enter some tags first!")
            
    except Exception as e:
        st.error("Failed to initialize the application. Please check the logs for details.")
        st.error(f"Error: {str(e)}")

if __name__ == "__main__":
    main()