torVik committed on
Commit 2087d05 · verified · 1 Parent(s): 229973e

Update app.py

Files changed (1): app.py (+190, -4)
app.py CHANGED
@@ -1,7 +1,193 @@
+import os
+from threading import Thread
+from typing import Iterator
+
 import gradio as gr
-
-def greet(name):
-    return "Hello " + name + "!!"
-
-demo = gr.Interface(fn=greet, inputs="text", outputs="text")
-demo.launch()
+import spaces
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+
+# Debugging: start of script
+print("Starting script...")
+
+HF_TOKEN = os.environ.get("HF_TOKEN")
+if HF_TOKEN is None:
+    print("Warning: HF_TOKEN is not set!")
+
+PASSWORD = os.getenv("APP_PASSWORD", "mysecretpassword")  # Set your desired password here or via environment variable
+
+DESCRIPTION = "# FT of Lama"
+
+if not torch.cuda.is_available():
+    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
+    print("Warning: no GPU available. This model cannot run on CPU.")
+else:
+    print("GPU is available!")
+
+MAX_MAX_NEW_TOKENS = 2048
+DEFAULT_MAX_NEW_TOKENS = 1024
+MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
+# Debugging: GPU check passed, loading model
+if torch.cuda.is_available():
+    model_id = "INSAIT-Institute/BgGPT-Gemma-2-27B-IT-v1.0"
+    try:
+        print("Loading model...")
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id, torch_dtype=torch.float16, device_map="auto", token=HF_TOKEN
+        )
+        print("Model loaded successfully!")
+
+        print("Loading tokenizer...")
+        tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN)
+        print("Tokenizer loaded successfully!")
+    except Exception as e:
+        print(f"Error loading model or tokenizer: {e}")
+        raise  # Re-raise the error after logging it
+
+
+@spaces.GPU
+def generate(
+    message: str,
+    chat_history: list[tuple[str, str]],
+    max_new_tokens: int = 1024,
+    temperature: float = 0.6,
+    top_p: float = 0.9,
+    top_k: int = 50,
+    repetition_penalty: float = 1.2,
+) -> Iterator[str]:
+    print(f"Received message: {message}")
+    print(f"Chat history: {chat_history}")
+
+    # Rebuild the conversation in the role/content format the chat template expects.
+    conversation = []
+    for user, assistant in chat_history:
+        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+    conversation.append({"role": "user", "content": message})
+
+    try:
+        print("Tokenizing input...")
+        # add_generation_prompt=True ends the template with the assistant turn
+        # so the model generates a reply instead of continuing the user turn.
+        input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt=True)
+        print(f"Input tokenized: {input_ids.shape}")
+
+        if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+            input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+            gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+            print("Trimmed input tokens due to length.")
+
+        input_ids = input_ids.to(model.device)
+        print("Input moved to the model's device.")
+
+        streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+        generate_kwargs = dict(
+            input_ids=input_ids,
+            streamer=streamer,
+            max_new_tokens=max_new_tokens,
+            do_sample=True,
+            top_p=top_p,
+            top_k=top_k,
+            temperature=temperature,
+            num_beams=1,
+            repetition_penalty=repetition_penalty,
+        )
+
+        print("Starting generation...")
+        # model.generate blocks until completion, so it runs in a background
+        # thread while this generator yields partial text from the streamer.
+        t = Thread(target=model.generate, kwargs=generate_kwargs)
+        t.start()
+        print("Thread started for model generation.")
+
+        outputs = []
+        for text in streamer:
+            outputs.append(text)
+            print(f"Generated text so far: {''.join(outputs)}")
+            yield "".join(outputs)
+
+    except Exception as e:
+        print(f"Error during generation: {e}")
+        raise  # Re-raise the error after logging it
+
+
+def password_auth(password):
+    # Toggle visibility of the chat area vs. the error message.
+    if password == PASSWORD:
+        return gr.update(visible=True), gr.update(visible=False)
+    else:
+        return gr.update(visible=False), gr.update(visible=True, value="Incorrect password. Try again.")
+
+chat_interface = gr.ChatInterface(
+    fn=generate,
+    additional_inputs=[
+        gr.Slider(
+            label="Max new tokens",
+            minimum=1,
+            maximum=MAX_MAX_NEW_TOKENS,
+            step=1,
+            value=DEFAULT_MAX_NEW_TOKENS,
+        ),
+        gr.Slider(
+            label="Temperature",
+            minimum=0.1,
+            maximum=4.0,
+            step=0.1,
+            value=0.6,
+        ),
+        gr.Slider(
+            label="Top-p (nucleus sampling)",
+            minimum=0.05,
+            maximum=1.0,
+            step=0.05,
+            value=0.9,
+        ),
+        gr.Slider(
+            label="Top-k",
+            minimum=1,
+            maximum=1000,
+            step=1,
+            value=50,
+        ),
+        gr.Slider(
+            label="Repetition penalty",
+            minimum=1.0,
+            maximum=2.0,
+            step=0.05,
+            value=1.2,
+        ),
+    ],
+    stop_btn=None,
+    examples=[
+        ["Hello there! How are you doing?"],
+        ["Can you explain briefly to me what is the Python programming language?"],
+        ["Explain the plot of Cinderella in a sentence."],
+        ["How many hours does it take a man to eat a Helicopter?"],
+        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
+    ],
+)
+
+# Debugging: interface setup
+print("Setting up interface...")
+
+with gr.Blocks(css="style.css") as demo:
+    gr.Markdown(DESCRIPTION)
+
+    # Create login components
+    with gr.Row(visible=True) as login_area:
+        password_input = gr.Textbox(
+            label="Enter Password", type="password", placeholder="Password", show_label=True
+        )
+        login_btn = gr.Button("Submit")
+        incorrect_password_msg = gr.Markdown("Incorrect password. Try again.", visible=False)
+
+    # Main chat interface, hidden until the password check succeeds
+    with gr.Column(visible=False) as chat_area:
+        gr.Markdown(DESCRIPTION)
+        gr.DuplicateButton(
+            value="Duplicate Space for private use",
+            elem_id="duplicate-button",
+            visible=os.getenv("SHOW_DUPLICATE_BUTTON") == "1",
+        )
+        chat_interface.render()
+
+    # Bind login button to check password
+    login_btn.click(password_auth, inputs=password_input, outputs=[chat_area, incorrect_password_msg])
+
+# Debugging: starting the queue and launching the demo
+print("Launching demo...")
+
+if __name__ == "__main__":
+    demo.queue(max_size=20).launch(share=True)
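
Note for reviewers: generate() streams tokens by running model.generate in a worker thread and reading decoded text from a TextIteratorStreamer on the main thread. Below is a minimal, self-contained sketch of that pattern, assuming a small CPU-friendly model ("gpt2", illustrative only; this Space loads the 27B model above):

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Illustrative model choice so the sketch runs on CPU.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hello, my name is", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# model.generate blocks until completion, so it runs in a worker thread
# while the main thread consumes text chunks as they are produced.
thread = Thread(target=model.generate, kwargs=dict(**inputs, streamer=streamer, max_new_tokens=30))
thread.start()

for chunk in streamer:
    print(chunk, end="", flush=True)
thread.join()

The generate() in this commit follows the same shape, adding the chat template, input trimming, and sampling parameters before the thread starts.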