Sharmiaji commited on
Commit
f7d803f
·
verified ·
1 Parent(s): 7ba59b6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +408 -0
app.py ADDED
@@ -0,0 +1,408 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Import necessary libraries
2
+ import streamlit as st
3
+ import mysql.connector
4
+ import bcrypt
5
+ import datetime
6
+ import re
7
+ import pytz
8
+ import time
9
+
10
+ # Import transformers library for GPT-2 model
11
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
12
+ import torch
13
+
14
+ # Configure Streamlit page settings
15
+ icon='https://media.istockphoto.com/id/1413286466/vector/chat-bot-icon-robot-virtual-assistant-bot-vector-illustration.jpg?s=612x612&w=0&k=20&c=ZSG3eqGPDJgIgFUIuVxID64uVUF3eqM3LrrDWtaKses='
16
+ st.set_page_config(page_title='SJS-TRANS',page_icon=icon, menu_items={"about":'This streamlit application was developed by S.S.SHARMILA'})
17
+
18
+ # Connect to TiDB Cloud database
19
+ mydb = mysql.connector.connect(
20
+ host="gateway01.ap-southeast-1.prod.aws.tidbcloud.com", # Replace with your TiDB host URL
21
+ port=4000,
22
+ user="4TZNT9dUm5s8BMa.root", # Replace with your TiDB username
23
+ password="JZcYiJAMHjEYNN6Z" # Replace with your TiDB password
24
+ )
25
+ mycursor = mydb.cursor(buffered=True)
26
+
27
+ # Create 'GUVI_DB' database and use
28
+ mycursor.execute("CREATE DATABASE IF NOT EXISTS SJS-TRANS")
29
+ mycursor.execute('USE SJS-TRANS')
30
+
31
+ # Create 'users' table if it does not exist
32
+ mycursor.execute('''CREATE TABLE IF NOT EXISTS users (
33
+ id INT AUTO_INCREMENT PRIMARY KEY,
34
+ username VARCHAR(50) UNIQUE NOT NULL,
35
+ password VARCHAR(255) NOT NULL,
36
+ email VARCHAR(255) UNIQUE NOT NULL,
37
+ registered_date TIMESTAMP,
38
+ last_login TIMESTAMP
39
+ );''')
40
+
41
+ # Check if username exists in the database
42
+ def username_exists(username):
43
+ mycursor.execute("SELECT * FROM users WHERE username = %s", (username,))
44
+ return mycursor.fetchone() is not None
45
+
46
+ # Check if email exists in the database
47
+ def email_exists(email):
48
+ mycursor.execute("SELECT * FROM users WHERE email = %s", (email,))
49
+ return mycursor.fetchone() is not None
50
+
51
+ # Validate email format using regular expressions
52
+ def is_valid_email(email):
53
+ pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
54
+ return re.match(pattern, email) is not None
55
+
56
+ # Create a new user in the database
57
+ def create_user(username, password, email):
58
+ if username_exists(username):
59
+ return 'username_exists'
60
+
61
+ if email_exists(email):
62
+ return 'email_exists'
63
+
64
+ hashed_password = bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt())
65
+ registered_date = datetime.datetime.now(pytz.timezone('Asia/Kolkata'))
66
+
67
+ # Insert user data into 'users' table
68
+ mycursor.execute(
69
+ "INSERT INTO users (username, password, email, registered_date) VALUES (%s, %s, %s, %s)",
70
+ (username, hashed_password, email, registered_date)
71
+ )
72
+ mydb.commit()
73
+ return 'success'
74
+
75
+ # Verify user credentials
76
+ def verify_user(username, password):
77
+ mycursor.execute("SELECT password FROM users WHERE username = %s", (username,))
78
+ record = mycursor.fetchone()
79
+ if record and bcrypt.checkpw(password.encode('utf-8'), record[0].encode('utf-8')):
80
+ # Update last login timestamp
81
+ mycursor.execute("UPDATE users SET last_login = %s WHERE username = %s", (datetime.datetime.now(pytz.timezone('Asia/Kolkata')), username))
82
+ mydb.commit()
83
+ return True
84
+ return False
85
+
86
+ # Reset user password
87
+ def reset_password(username, new_password):
88
+ hashed_password = bcrypt.hashpw(new_password.encode('utf-8'), bcrypt.gensalt())
89
+ # Update password in 'users' table
90
+ mycursor.execute(
91
+ "UPDATE users SET password = %s WHERE username = %s",
92
+ (hashed_password, username)
93
+ )
94
+ mydb.commit()
95
+
96
+ # Session state management
97
+ if 'sign_up_successful' not in st.session_state:
98
+ st.session_state.sign_up_successful = False
99
+ if 'login_successful' not in st.session_state:
100
+ st.session_state.login_successful = False
101
+ if 'reset_password' not in st.session_state:
102
+ st.session_state.reset_password = False
103
+ if 'username' not in st.session_state:
104
+ st.session_state.username = ''
105
+ if 'current_page' not in st.session_state:
106
+ st.session_state.current_page = 'login'
107
+
108
+
109
+ # Login form
110
+ def login():
111
+ with st.form(key='login', clear_on_submit=True):
112
+ st.subheader(':blue[**Login**]')
113
+ st.write("Enter your username and password below.")
114
+
115
+ # Input fields for username and password
116
+ username = st.text_input(label='Username', placeholder='Enter Your Username')
117
+ password = st.text_input(label='Password', placeholder='Enter Your Password', type='password')
118
+
119
+ if st.form_submit_button('Login'):
120
+ if not username or not password:
121
+ st.error("Please fill out all fields.")
122
+ elif verify_user(username, password):
123
+ st.session_state.login_successful = True
124
+ st.session_state.username = username
125
+ st.session_state.current_page = 'home'
126
+ st.rerun()
127
+ else:
128
+ st.error("Incorrect username or password. If you don't have an account, please sign up.")
129
+
130
+ # Display sign-up and reset password button
131
+ if not st.session_state.login_successful:
132
+ c1, c2 = st.columns(2)
133
+ with c1:
134
+ st.write(":red[New user?]")
135
+ if st.button('Sign Up'):
136
+ st.session_state.current_page = 'sign_up'
137
+ st.rerun()
138
+ with c2:
139
+ st.write(":red[Forgot Password?]")
140
+ if st.button('Reset Password'):
141
+ st.session_state.current_page = 'reset_password'
142
+ st.rerun()
143
+
144
+
145
+ # Sign-up form
146
+ def signup():
147
+ with st.form(key='signup', clear_on_submit=True):
148
+ st.subheader(':blue[**Sign Up**]')
149
+ st.write("Enter the required fields to create a new account.")
150
+
151
+ # Input fields for email, username, and password
152
+ email = st.text_input(label='Email', placeholder='Enter Your Email')
153
+ username = st.text_input(label='Username', placeholder='Enter Your Username')
154
+ password = st.text_input(label='Password', placeholder='Enter Your Password', type='password')
155
+ re_password = st.text_input(label='Confirm Password', placeholder='Confirm Your Password', type='password')
156
+
157
+
158
+ if st.form_submit_button('Sign Up'):
159
+ if not email or not username or not password or not re_password:
160
+ st.error("Please fill out all fields.")
161
+ elif not is_valid_email(email):
162
+ st.error("Please enter a valid email address.")
163
+ elif len(password) <= 3:
164
+ st.error("Password too short")
165
+ elif password != re_password:
166
+ st.error("Passwords do not match. Please re-enter.")
167
+ else:
168
+ result = create_user(username, password, email)
169
+ if result == 'username_exists':
170
+ st.error("Username already registered. Please use a different username.")
171
+ elif result == 'email_exists':
172
+ st.error("Email already registered. Please use a different email.")
173
+ elif result == 'success':
174
+ st.success(f"Username {username} created successfully! Please login.")
175
+ st.session_state.sign_up_successful = True
176
+ else:
177
+ st.error("Failed to create user. Please try again later.")
178
+
179
+ if st.session_state.sign_up_successful:
180
+ if st.button('Go to Login'):
181
+ st.session_state.current_page = 'login'
182
+ st.rerun()
183
+
184
+
185
+ # Reset password form
186
+ def reset_password_page():
187
+ with st.form(key='reset_password', clear_on_submit=True):
188
+ st.subheader(':blue[Reset Password]')
189
+ st.write("Enter your username and new password below.")
190
+
191
+ # Input fields for username and new password
192
+ username = st.text_input(label='Username', value='')
193
+ new_password = st.text_input(label='New Password', type='password')
194
+ re_password = st.text_input(label='Confirm New Password', type='password')
195
+
196
+ if st.form_submit_button('Reset Password'):
197
+ if not username:
198
+ st.error("Please enter your username.")
199
+ elif not username_exists(username):
200
+ st.error("Username not found. Please enter a valid username.")
201
+ elif not new_password or not re_password:
202
+ st.error("Please fill out all fields.")
203
+ elif len(new_password) <= 3:
204
+ st.error("Password too short")
205
+ elif new_password != re_password:
206
+ st.error("Passwords do not match. Please re-enter.")
207
+ else:
208
+ reset_password(username, new_password)
209
+ st.success("Password reset successfully. Please login with your new password.")
210
+ st.session_state.current_page = 'login'
211
+
212
+ # Button to return to login page
213
+ st.write('Return to Login page')
214
+ if st.button('Login'):
215
+ st.session_state.current_page = 'login'
216
+ st.rerun()
217
+
218
+
219
+
220
+ # HTML and CSS for the animated title and Disclaimer
221
+ html_content = """
222
+ <!DOCTYPE html>
223
+ <html lang="en">
224
+ <head>
225
+ <meta charset="UTF-8">
226
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
227
+ <style>
228
+ body {
229
+ font-family: Arial, sans-serif;
230
+ display: flex;
231
+ align-items: center;
232
+ height: 100vh;
233
+ margin: 0;
234
+ background-color: #0E1117;
235
+ color: #ffffff;
236
+ }
237
+ .title {
238
+ font-size: 3rem;
239
+ }
240
+ .animated {
241
+ display: inline-block;
242
+ background: linear-gradient(90deg, #ff5733, #33ff57, #3357ff, #ff33a1);
243
+ background-size: 400% 400%;
244
+ -webkit-background-clip: text;
245
+ -webkit-text-fill-color: transparent;
246
+ animation: gradient 8s ease infinite;
247
+ }
248
+ @keyframes gradient {
249
+ 0% {
250
+ background-position: 0% 50%;
251
+ }
252
+ 50% {
253
+ background-position: 100% 50%;
254
+ }
255
+ 100% {
256
+ background-position: 0% 50%;
257
+ }
258
+ }
259
+ </style>
260
+ </head>
261
+ <body>
262
+ <h2 class="title">
263
+ <span class="animated">GUVI GPT - Text Generator</span>
264
+ </h2>
265
+ </body>
266
+ </html>
267
+ """
268
+
269
+
270
+ # CSS style for the running ticker effect
271
+ ticker_style = """
272
+ <style>
273
+ .ticker-wrap {
274
+ overflow: hidden;
275
+ position: relative;
276
+ box-sizing: border-box;
277
+ padding: 10px;
278
+ background-color: #0E1117;
279
+ color: #FAFAFA;
280
+ font-size: 18px;
281
+ white-space: nowrap;
282
+ }
283
+ .ticker-item {
284
+ display: inline-block;
285
+ padding-right: 30px;
286
+ animation: ticker-slide 15s linear infinite;
287
+ }
288
+ @keyframes ticker-slide {
289
+ 0% {
290
+ transform: translateX(100%);
291
+ }
292
+ 100% {
293
+ transform: translateX(-100%);
294
+ }
295
+ }
296
+ </style>
297
+ """
298
+
299
+
300
+ img='https://media0.giphy.com/media/v1.Y2lkPTc5MGI3NjExOGttcXJnNXZ5azl5bnNrbDk1bnZ4eTJkeWNtbGJhM2I4ZDZheWplMiZlcD12MV9pbnRlcm5hbF9naWZfYnlfaWQmY3Q9Zw/5k5vZwRFZR5aZeniqb/giphy.webp'
301
+
302
+ disclaimer_message = "Disclaimer: Model data sourced from various web articles. Performance may vary based on data quality and relevance."
303
+
304
+
305
+
306
+ # Load the fine-tuned model and tokenizer
307
+ model_name_or_path = "./fine_tuned_model"
308
+ model = GPT2LMHeadModel.from_pretrained(model_name_or_path)
309
+ tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path)
310
+
311
+ # Set the pad_token to eos_token if it's not already set
312
+ if tokenizer.pad_token is None:
313
+ tokenizer.pad_token = tokenizer.eos_token
314
+
315
+ # Move the model to GPU if available
316
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
317
+ model.to(device)
318
+
319
+
320
+ # Define the text generation function
321
+ def generate_text(model, tokenizer, seed_text, max_length=100, temperature=0.01, num_return_sequences=1):
322
+ # Tokenize the input text with padding
323
+ inputs = tokenizer(seed_text, return_tensors='pt', padding=True, truncation=True)
324
+
325
+ input_ids = inputs['input_ids'].to(device)
326
+ attention_mask = inputs['attention_mask'].to(device)
327
+
328
+ # Generate text
329
+ with torch.no_grad():
330
+ output = model.generate(
331
+ input_ids,
332
+ attention_mask=attention_mask,
333
+ max_length=max_length,
334
+ temperature=temperature,
335
+ num_return_sequences=num_return_sequences,
336
+ do_sample=True,
337
+ top_k=50,
338
+ top_p=0.1,
339
+ pad_token_id=tokenizer.eos_token_id # Ensure padding token is set to eos_token_id
340
+ )
341
+
342
+ # Decode the generated text
343
+
344
+ generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
345
+
346
+ for word in generated_text.split():
347
+ yield word + " "
348
+ time.sleep(0.1)
349
+
350
+
351
+ #home page
352
+ def home_page():
353
+
354
+ with st.sidebar:
355
+
356
+ st.title(f"Welcome, {st.session_state.username}!")
357
+ st.image(img,use_column_width=True)
358
+
359
+ st.markdown('<br>',unsafe_allow_html=True)
360
+
361
+ st.write("### Example Prompts")
362
+ st.markdown(''' Guvi is an <br> Founders of guvi''',unsafe_allow_html=True)
363
+
364
+ st.markdown('<br>',unsafe_allow_html=True)
365
+
366
+ max=st.slider('Select MAX words',10,250)
367
+
368
+ if st.button("Logout"):
369
+ st.session_state.clear()
370
+ st.session_state.current_page = 'login'
371
+ st.rerun()
372
+
373
+
374
+
375
+ # Display the animated title in Streamlit
376
+ st.markdown(html_content, unsafe_allow_html=True)
377
+ st.markdown(ticker_style, unsafe_allow_html=True)
378
+ st.markdown(f'<div class="ticker-wrap"><div class="ticker-item">{disclaimer_message}</div></div>', unsafe_allow_html=True)
379
+
380
+
381
+ if "messages" not in st.session_state:
382
+ st.session_state.messages = []
383
+
384
+ if prompt := st.chat_input("What up?"):
385
+
386
+ with st.chat_message("user"):
387
+ st.markdown(prompt)
388
+
389
+ st.session_state.messages.append({"role": "user", "content": prompt})
390
+
391
+ with st.chat_message("assistant"):
392
+
393
+ response = st.write_stream(generate_text(model, tokenizer, seed_text=prompt, max_length=max, temperature=0.01, num_return_sequences=1))
394
+
395
+ st.session_state.messages.append({"role": "assistant", "content": response})
396
+
397
+
398
+
399
+
400
+ # Display appropriate page based on session state
401
+ if st.session_state.current_page == 'home':
402
+ home_page()
403
+ elif st.session_state.current_page == 'login':
404
+ login()
405
+ elif st.session_state.current_page == 'sign_up':
406
+ signup()
407
+ elif st.session_state.current_page == 'reset_password':
408
+ reset_password_page()