acecalisto3 committed on
Commit 7321c1c · 1 Parent(s): 1ce0e25

Automated update: Implement model swapping

Files changed (2)
  1. CEEMEESEEK +1 -0
  2. app.py +18 -1229
CEEMEESEEK ADDED
@@ -0,0 +1 @@
1
+ Subproject commit 1ce0e259eecb0685bbb28e98ce87ffe73b71e776
app.py CHANGED
@@ -1,1176 +1,29 @@
1
- import os
2
- import time
3
- import hashlib
4
- import logging
5
- import streamlit as st
6
- import datetime
7
- import csv
8
- import threading
9
- import re
10
- import unittest
11
- from urllib.parse import urlparse
12
- import spaces
13
 
14
- import pandas as pd
15
- from selenium import webdriver
16
- from selenium.webdriver.chrome.service import Service
17
- from selenium.webdriver.chrome.options import Options
18
- from selenium.webdriver.common.by import By
19
- from selenium.webdriver.support.ui import WebDriverWait
20
- from selenium.webdriver.support import expected_conditions as EC
21
- from selenium.common.exceptions import (
22
- TimeoutException,
23
 
24
- NoSuchElementException,
25
- StaleElementReferenceException,
26
- )
27
- from webdriver_manager.chrome import ChromeDriverManager
28
 
29
- from transformers import AutoTokenizer, OpenLlamaForCausalLM, pipeline
30
- import gradio as gr
31
- import xml.etree.ElementTree as ET
32
- import torch
33
- import mysql.connector
34
- from mysql.connector import errorcode, pooling
35
- import nltk
36
- import importlib
37
 
38
- st.title("CEEMEESEEK with Model Selection")
39
 
40
- model_option = st.selectbox("Select a Model", ["Falcon", "Flan-T5", "Other Model"]) # Add your model names
41
 
42
- if model_option == "Falcon":
43
- model_module = importlib.import_module("model_falcon") # Assuming you create model_falcon.py
44
- model = model_module.load_falcon_model()
45
- elif model_option == "Flan-T5":
46
- model_module = importlib.import_module("model_flan_t5")
47
- model = model_module.load_flan_t5_model()
48
-
49
- HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
50
-
51
- if not HUGGINGFACE_TOKEN:
52
- raise ValueError("HUGGINGFACE_TOKEN is not set in the environment variables.")
53
- add_to_git_credential=True
54
- login(token=HUGGINGFACE_TOKEN, add_to_git_credential=True)
55
-
56
-
57
- # Load environment variables from .env file
58
- load_dotenv()
59
-
60
- # Configure logging
61
- logging.basicConfig(
62
- level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
63
- )
64
-
65
- # Define constants
66
- DEFAULT_FILE_PATH = "scraped_data"
67
- PURPOSE = (
68
- "You go to Culvers sites, you continuously seek changes on them since your last observation. "
69
- "Anything new that gets logged and dumped into csv, stored in your log folder at user/app/scraped_data."
70
- )
71
-
72
- # Global variables for task management
73
- HISTORY = []
74
- CURRENT_TASK = None
75
- STOP_THREADS = False # Flag to stop scraping threads
76
-
77
- # Database Pooling Configuration
78
- DB_POOL_NAME = "mypool"
79
- DB_POOL_SIZE = 5 # Adjust based on expected load
80
-
81
- try:
82
- dbconfig = {
83
- "host": os.getenv("DB_HOST"),
84
- "user": os.getenv("DB_USER"),
85
- "password": os.getenv("DB_PASSWORD"),
86
- "database": os.getenv("DB_NAME"),
87
- }
88
- connection_pool = mysql.connector.pooling.MySQLConnectionPool(
89
- pool_name=DB_POOL_NAME,
90
- pool_size=DB_POOL_SIZE,
91
- pool_reset_session=True,
92
- **dbconfig
93
- )
94
- logging.info("Database connection pool created successfully.")
95
- except mysql.connector.Error as err:
96
- logging.warning(f"Database connection pool creation failed: {err}")
97
- connection_pool = None # Will use CSV as fallback
98
-
99
- # Function to get a database connection from the pool
100
- def get_db_connection():
101
- """
102
- Retrieves a connection from the pool. Returns None if pool is not available.
103
- """
104
- if connection_pool:
105
- try:
106
- connection = connection_pool.get_connection()
107
- if connection.is_connected():
108
- return connection
109
- except mysql.connector.Error as err:
110
- logging.error(f"Error getting connection from pool: {err}")
111
- return None
112
-
113
- # Initialize Database: Create tables and indexes
114
- def initialize_database():
115
- """
116
- Initializes the database by creating necessary tables and indexes if they do not exist.
117
- """
118
- connection = get_db_connection()
119
- if connection is None:
120
- logging.info("Database initialization skipped. Using CSV storage.")
121
- return
122
-
123
- cursor = connection.cursor()
124
- try:
125
- # Create table for scraped data
126
- create_scraped_data_table = """
127
- CREATE TABLE IF NOT EXISTS scraped_data (
128
- id INT AUTO_INCREMENT PRIMARY KEY,
129
- url VARCHAR(255) NOT NULL,
130
- content_hash VARCHAR(64) NOT NULL,
131
- change_detected DATETIME NOT NULL
132
- )
133
- """
134
- cursor.execute(create_scraped_data_table)
135
- logging.info("Table 'scraped_data' is ready.")
136
-
137
- # Create indexes for performance
138
- create_index_url = "CREATE INDEX IF NOT EXISTS idx_url ON scraped_data(url)"
139
- create_index_change = "CREATE INDEX IF NOT EXISTS idx_change_detected ON scraped_data(change_detected)"
140
- cursor.execute(create_index_url)
141
- cursor.execute(create_index_change)
142
- logging.info("Indexes on 'url' and 'change_detected' columns created.")
143
-
144
- # Create table for action logs
145
- create_action_logs_table = """
146
- CREATE TABLE IF NOT EXISTS action_logs (
147
- id INT AUTO_INCREMENT PRIMARY KEY,
148
- action VARCHAR(255) NOT NULL,
149
- timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
150
- )
151
- """
152
- cursor.execute(create_action_logs_table)
153
- logging.info("Table 'action_logs' is ready.")
154
-
155
- except mysql.connector.Error as err:
156
- logging.error(f"Error initializing database: {err}")
157
- finally:
158
- cursor.close()
159
- connection.close()
160
- logging.info("Database initialization complete.")
161
-
162
- # Function to create WebDriver
163
- def create_driver(options: Options) -> webdriver.Chrome:
164
- """
165
- Initializes and returns a Selenium Chrome WebDriver instance.
166
- """
167
- try:
168
- driver = webdriver.Chrome(
169
- service=Service(ChromeDriverManager().install()), options=options
170
- )
171
- logging.info("ChromeDriver initialized successfully.")
172
- return driver
173
- except Exception as exception:
174
- logging.error(f"Error initializing ChromeDriver: {exception}")
175
- return None
176
-
177
- # Function to log changes to CSV
178
- def log_to_csv(storage_location: str, url: str, content_hash: str, change_detected: str):
179
- """
180
- Logs the change to a CSV file in the storage_location.
181
- """
182
- try:
183
- os.makedirs(storage_location, exist_ok=True)
184
- csv_file_path = os.path.join(storage_location, f"{urlparse(url).hostname}_changes.csv")
185
- file_exists = os.path.isfile(csv_file_path)
186
-
187
- with open(csv_file_path, "a", newline="", encoding="utf-8") as csvfile:
188
- fieldnames = ["date", "time", "url", "content_hash", "change"]
189
- writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
190
- if not file_exists:
191
- writer.writeheader()
192
- writer.writerow(
193
- {
194
- "date": change_detected.split()[0],
195
- "time": change_detected.split()[1],
196
- "url": url,
197
- "content_hash": content_hash,
198
- "change": "Content changed",
199
- }
200
- )
201
- logging.info(f"Change detected at {url} on {change_detected} and logged to CSV.")
202
- except Exception as e:
203
- logging.error(f"Error logging data to CSV: {e}")
204
-
205
- # Function to get initial observation
206
- def get_initial_observation(
207
- driver: webdriver.Chrome, url: str, content_type: str, selector: str = None
208
- ) -> str:
209
- """
210
- Retrieves the initial content from the URL and returns its MD5 hash.
211
- """
212
- try:
213
- driver.get(url)
214
- WebDriverWait(driver, 10).until(EC.presence_of_element_located((By.TAG_NAME, "body")))
215
- time.sleep(2) # Additional wait for dynamic content
216
-
217
- if content_type == "text":
218
- initial_content = driver.page_source
219
- elif content_type == "media":
220
- if selector:
221
- try:
222
- elements = WebDriverWait(driver, 5).until(
223
- EC.presence_of_all_elements_located((By.CSS_SELECTOR, selector))
224
- )
225
- initial_content = [element.get_attribute("src") for element in elements]
226
- except TimeoutException:
227
- logging.warning(f"Timeout waiting for media elements with selector '{selector}' on {url}")
228
- initial_content = []
229
- else:
230
- elements = driver.find_elements(By.TAG_NAME, "img")
231
- initial_content = [element.get_attribute("src") for element in elements]
232
- else:
233
- initial_content = driver.page_source
234
-
235
- initial_hash = hashlib.md5(str(initial_content).encode("utf-8")).hexdigest()
236
- logging.info(f"Initial hash for {url}: {initial_hash}")
237
- return initial_hash
238
- except Exception as exception:
239
- logging.error(f"Error accessing {url}: {exception}")
240
- return None
241
-
242
- # Function to monitor URLs for changes
243
- def monitor_urls(
244
- storage_location: str,
245
- urls: list,
246
- scrape_interval: int,
247
- content_type: str,
248
- selector: str = None,
249
- progress: gr.Progress = None
250
- ):
251
- """
252
- Monitors the specified URLs for changes and logs any detected changes to the database or CSV.
253
- """
254
- global HISTORY, STOP_THREADS
255
- previous_hashes = {url: "" for url in urls}
256
-
257
- options = Options()
258
- options.add_argument("--headless")
259
- options.add_argument("--no-sandbox")
260
- options.add_argument("--disable-dev-shm-usage")
261
-
262
- driver = create_driver(options)
263
- if driver is None:
264
- logging.error("WebDriver could not be initialized. Exiting monitor.")
265
- return
266
-
267
- try:
268
- while not STOP_THREADS:
269
- for url in urls:
270
- if STOP_THREADS:
271
- break
272
- try:
273
- driver.get(url)
274
- WebDriverWait(driver, 10).until(EC.presence_of_element_located((By.TAG_NAME, "body")))
275
- time.sleep(2) # Additional wait for dynamic content
276
-
277
- if content_type == "text":
278
- current_content = driver.page_source
279
- elif content_type == "media":
280
- if selector:
281
- try:
282
- elements = WebDriverWait(driver, 5).until(
283
- EC.presence_of_all_elements_located((By.CSS_SELECTOR, selector))
284
- )
285
- current_content = [element.get_attribute("src") for element in elements]
286
- except TimeoutException:
287
- logging.warning(f"Timeout waiting for media elements with selector '{selector}' on {url}")
288
- current_content = []
289
- else:
290
- elements = driver.find_elements(By.TAG_NAME, "img")
291
- current_content = [element.get_attribute("src") for element in elements]
292
- else:
293
- current_content = driver.page_source
294
-
295
- current_hash = hashlib.md5(str(current_content).encode("utf-8")).hexdigest()
296
- if current_hash != previous_hashes[url]:
297
- previous_hashes[url] = current_hash
298
- date_time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
299
- HISTORY.append(f"Change detected at {url} on {date_time_str}")
300
-
301
- # Attempt to log to database
302
- connection = get_db_connection()
303
- if connection:
304
- try:
305
- cursor = connection.cursor()
306
- insert_query = """
307
- INSERT INTO scraped_data (url, content_hash, change_detected)
308
- VALUES (%s, %s, %s)
309
- """
310
- cursor.execute(insert_query, (url, current_hash, date_time_str))
311
- connection.commit()
312
- logging.info(f"Change detected at {url} on {date_time_str} and logged to database.")
313
- except mysql.connector.Error as err:
314
- logging.error(f"Error inserting data into database: {err}")
315
- # Fallback to CSV
316
- log_to_csv(storage_location, url, current_hash, date_time_str)
317
- finally:
318
- cursor.close()
319
- connection.close()
320
- else:
321
- # Fallback to CSV
322
- log_to_csv(storage_location, url, current_hash, date_time_str)
323
-
324
- # Update progress
325
- if progress:
326
- progress(1)
327
- except (
328
- NoSuchElementException,
329
- StaleElementReferenceException,
330
- TimeoutException,
331
- Exception,
332
- ) as e:
333
- logging.error(f"Error accessing {url}: {e}")
334
- if progress:
335
- progress(1)
336
- time.sleep(scrape_interval * 60) # Wait for the next scrape interval
337
- finally:
338
- driver.quit()
339
- logging.info("ChromeDriver session ended.")
340
-
341
- # Function to start scraping
342
- def start_scraping(
343
- storage_location: str,
344
- urls: str,
345
- scrape_interval: int,
346
- content_type: str,
347
- selector: str = None,
348
- progress: gr.Progress = None
349
- ) -> str:
350
- """
351
- Starts the scraping process in a separate thread with progress indication.
352
- """
353
- global CURRENT_TASK, HISTORY, STOP_THREADS
354
-
355
- if STOP_THREADS:
356
- STOP_THREADS = False # Reset the flag if previously stopped
357
-
358
- url_list = [url.strip() for url in urls.split(",") if url.strip()]
359
- CURRENT_TASK = f"Monitoring URLs: {', '.join(url_list)}"
360
- HISTORY.append(f"Task started: {CURRENT_TASK}")
361
- logging.info(f"Task started: {CURRENT_TASK}")
362
-
363
- # Initialize database tables
364
- initialize_database()
365
-
366
- # Log initial observations
367
- def log_initial_observations():
368
- options = Options()
369
- options.add_argument("--headless")
370
- options.add_argument("--no-sandbox")
371
- options.add_argument("--disable-dev-shm-usage")
372
-
373
- driver = create_driver(options)
374
- if driver is None:
375
- return
376
-
377
- for url in url_list:
378
- if STOP_THREADS:
379
- break
380
- try:
381
- initial_hash = get_initial_observation(driver, url, content_type, selector)
382
- if initial_hash:
383
- date_time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
384
- HISTORY.append(f"Initial observation at {url}: {initial_hash}")
385
-
386
- # Attempt to log to database
387
- connection = get_db_connection()
388
- if connection:
389
- try:
390
- cursor = connection.cursor()
391
- insert_query = """
392
- INSERT INTO scraped_data (url, content_hash, change_detected)
393
- VALUES (%s, %s, %s)
394
- """
395
- cursor.execute(insert_query, (url, initial_hash, date_time_str))
396
- connection.commit()
397
- logging.info(f"Initial observation logged for {url} in database.")
398
- except mysql.connector.Error as err:
399
- logging.error(f"Error inserting initial observation into database: {err}")
400
- # Fallback to CSV
401
- log_to_csv(storage_location, url, initial_hash, date_time_str)
402
- finally:
403
- cursor.close()
404
- connection.close()
405
- else:
406
- # Fallback to CSV
407
- log_to_csv(storage_location, url, initial_hash, date_time_str)
408
- except Exception as e:
409
- HISTORY.append(f"Error accessing {url}: {e}")
410
- logging.error(f"Error accessing {url}: {e}")
411
- driver.quit()
412
-
413
- # Start logging initial observations
414
- initial_thread = threading.Thread(target=log_initial_observations, daemon=True)
415
- initial_thread.start()
416
-
417
- # Start the monitoring thread with progress
418
- monitor_thread = threading.Thread(
419
- target=monitor_urls,
420
- args=(storage_location, url_list, scrape_interval, content_type, selector, progress),
421
- daemon=True,
422
- )
423
- monitor_thread.start()
424
- logging.info("Started scraping thread.")
425
- return f"Started scraping {', '.join(url_list)} every {scrape_interval} minutes."
426
-
427
- # Function to stop scraping
428
- def stop_scraping() -> str:
429
- """
430
- Stops all ongoing scraping threads.
431
- """
432
- global STOP_THREADS
433
- STOP_THREADS = True
434
- HISTORY.append("Scraping stopped by user.")
435
- logging.info("Scraping stop signal sent.")
436
- return "Scraping has been stopped."
437
-
438
- # Function to display CSV content from MySQL or CSV
439
- def display_csv(storage_location: str, url: str) -> str:
440
- """
441
- Fetches and returns the scraped data for a given URL from the MySQL database or CSV.
442
- """
443
- try:
444
- connection = get_db_connection()
445
- if connection:
446
- try:
447
- cursor = connection.cursor(dictionary=True)
448
- query = "SELECT * FROM scraped_data WHERE url = %s ORDER BY change_detected DESC"
449
- cursor.execute(query, (url,))
450
- results = cursor.fetchall()
451
-
452
- if not results:
453
- return "No data available for the selected URL."
454
-
455
- df = pd.DataFrame(results)
456
- cursor.close()
457
- connection.close()
458
- return df.to_string(index=False)
459
- except mysql.connector.Error as err:
460
- logging.error(f"Error fetching data from database: {err}")
461
- # Fallback to CSV
462
- else:
463
- logging.info("No database connection. Fetching data from CSV.")
464
-
465
- # Fallback to CSV
466
- hostname = urlparse(url).hostname
467
- csv_path = os.path.join(storage_location, f"{hostname}_changes.csv")
468
- if os.path.exists(csv_path):
469
- df = pd.read_csv(csv_path)
470
- return df.to_string(index=False)
471
- else:
472
- return "No data available."
473
-
474
- except Exception as e:
475
- logging.error(f"Error fetching data for {url}: {e}")
476
- return f"Error fetching data for {url}: {e}"
477
-
478
- # Function to generate RSS feed from MySQL or CSV data
479
- def generate_rss_feed(storage_location: str, url: str) -> str:
480
- """
481
- Generates an RSS feed for the latest changes detected on a given URL from the MySQL database or CSV.
482
- """
483
- try:
484
- connection = get_db_connection()
485
- rss_feed = ""
486
-
487
- if connection:
488
- try:
489
- cursor = connection.cursor(dictionary=True)
490
- query = "SELECT * FROM scraped_data WHERE url = %s ORDER BY change_detected DESC LIMIT 10"
491
- cursor.execute(query, (url,))
492
- results = cursor.fetchall()
493
-
494
- if not results:
495
- return "No changes detected to include in RSS feed."
496
-
497
- # Create the root RSS element
498
- rss = ET.Element("rss", version="2.0")
499
- channel = ET.SubElement(rss, "channel")
500
-
501
- # Add channel elements
502
- title = ET.SubElement(channel, "title")
503
- title.text = f"RSS Feed for {urlparse(url).hostname}"
504
-
505
- link = ET.SubElement(channel, "link")
506
- link.text = url
507
-
508
- description = ET.SubElement(channel, "description")
509
- description.text = "Recent changes detected on the website."
510
-
511
- # Add items to the feed
512
- for row in results:
513
- item = ET.SubElement(channel, "item")
514
-
515
- item_title = ET.SubElement(item, "title")
516
- item_title.text = f"Change detected at {row['url']}"
517
-
518
- item_link = ET.SubElement(item, "link")
519
- item_link.text = row["url"]
520
-
521
- item_description = ET.SubElement(item, "description")
522
- item_description.text = f"Content changed on {row['change_detected']}"
523
-
524
- pub_date = ET.SubElement(item, "pubDate")
525
- pub_date.text = datetime.datetime.strptime(
526
- str(row['change_detected']), "%Y-%m-%d %H:%M:%S"
527
- ).strftime("%a, %d %b %Y %H:%M:%S +0000")
528
-
529
- # Generate the XML string
530
- rss_feed = ET.tostring(rss, encoding="utf-8", method="xml").decode("utf-8")
531
- cursor.close()
532
- connection.close()
533
- return rss_feed
534
- except mysql.connector.Error as err:
535
- logging.error(f"Error fetching data from database: {err}")
536
- # Fallback to CSV
537
- else:
538
- logging.info("No database connection. Generating RSS feed from CSV.")
539
-
540
- # Fallback to CSV
541
- hostname = urlparse(url).hostname
542
- csv_path = os.path.join(storage_location, f"{hostname}_changes.csv")
543
- if os.path.exists(csv_path):
544
- df = pd.read_csv(csv_path).tail(10)
545
- if df.empty:
546
- return "No changes detected to include in RSS feed."
547
-
548
- # Create the root RSS element
549
- rss = ET.Element("rss", version="2.0")
550
- channel = ET.SubElement(rss, "channel")
551
-
552
- # Add channel elements
553
- title = ET.SubElement(channel, "title")
554
- title.text = f"RSS Feed for {hostname}"
555
-
556
- link = ET.SubElement(channel, "link")
557
- link.text = url
558
-
559
- description = ET.SubElement(channel, "description")
560
- description.text = "Recent changes detected on the website."
561
-
562
- # Add items to the feed
563
- for _, row in df.iterrows():
564
- item = ET.SubElement(channel, "item")
565
-
566
- item_title = ET.SubElement(item, "title")
567
- item_title.text = f"Change detected at {row['url']}"
568
-
569
- item_link = ET.SubElement(item, "link")
570
- item_link.text = row["url"]
571
-
572
- item_description = ET.SubElement(item, "description")
573
- item_description.text = f"Content changed on {row['date']} at {row['time']}"
574
-
575
- pub_date = ET.SubElement(item, "pubDate")
576
- pub_date.text = datetime.datetime.strptime(
577
- f"{row['date']} {row['time']}", "%Y-%m-%d %H:%M:%S"
578
- ).strftime("%a, %d %b %Y %H:%M:%S +0000")
579
-
580
- # Generate the XML string
581
- rss_feed = ET.tostring(rss, encoding="utf-8", method="xml").decode("utf-8")
582
- return rss_feed
583
- else:
584
- return "No data available."
585
-
586
- except Exception as e:
587
- logging.error(f"Error generating RSS feed for {url}: {e}")
588
- return f"Error generating RSS feed for {url}: {e}"
589
-
590
- # Function to parse user commands using spaCy
591
- def parse_command(message: str) -> tuple:
592
- """
593
- Parses the user message using spaCy to identify if it contains a command.
594
- Returns the command and its parameters if found, else (None, None).
595
- """
596
- doc = nlp(message.lower())
597
- command = None
598
- params = {}
599
-
600
- # Define command patterns
601
- if "filter" in message.lower():
602
- # Example: "Filter apples, oranges in column Description"
603
- match = re.search(r"filter\s+([\w\s,]+)\s+in\s+column\s+(\w+)", message, re.IGNORECASE)
604
- if match:
605
- words = [word.strip() for word in match.group(1).split(",")]
606
- column = match.group(2)
607
- command = "filter"
608
- params = {"words": words, "column": column}
609
-
610
- elif "sort" in message.lower():
611
- # Example: "Sort Price ascending"
612
- match = re.search(r"sort\s+(\w+)\s+(ascending|descending)", message, re.IGNORECASE)
613
- if match:
614
- column = match.group(1)
615
- order = match.group(2)
616
- command = "sort"
617
- params = {"column": column, "order": order}
618
-
619
- elif "export to csv as" in message.lower():
620
- # Example: "Export to CSV as filtered_data.csv"
621
- match = re.search(r"export\s+to\s+csv\s+as\s+([\w\-]+\.csv)", message, re.IGNORECASE)
622
- if match:
623
- filename = match.group(1)
624
- command = "export"
625
- params = {"filename": filename}
626
-
627
- elif "log action" in message.lower():
628
- # Example: "Log action Filtered data for specific fruits"
629
- match = re.search(r"log\s+action\s+(.+)", message, re.IGNORECASE)
630
- if match:
631
- action = match.group(1)
632
- command = "log"
633
- params = {"action": action}
634
-
635
- return command, params
636
-
637
- # Function to execute parsed commands
638
- def execute_command(command: str, params: dict) -> str:
639
- """
640
- Executes the corresponding function based on the command and parameters.
641
- """
642
- if command == "filter":
643
- words = params["words"]
644
- column = params["column"]
645
- return filter_data(column, words)
646
- elif command == "sort":
647
- column = params["column"]
648
- order = params["order"]
649
- return sort_data(column, order)
650
- elif command == "export":
651
- filename = params["filename"]
652
- return export_csv(filename)
653
- elif command == "log":
654
- action = params["action"]
655
- return log_action(action)
656
- else:
657
- return "Unknown command."
658
-
659
- # Data Manipulation Functions
660
- def filter_data(column: str, words: list) -> str:
661
- """
662
- Filters the scraped data to include only rows where the specified column contains the given words.
663
- Saves the filtered data to a new CSV file.
664
- """
665
- try:
666
- storage_location = DEFAULT_FILE_PATH
667
-
668
- connection = get_db_connection()
669
- if connection:
670
- try:
671
- cursor = connection.cursor(dictionary=True)
672
- # Fetch all data
673
- query = "SELECT * FROM scraped_data"
674
- cursor.execute(query)
675
- results = cursor.fetchall()
676
-
677
- if not results:
678
- return "No data available to filter."
679
-
680
- df = pd.DataFrame(results)
681
- # Create a regex pattern to match any of the words
682
- pattern = '|'.join(words)
683
- if column not in df.columns:
684
- return f"Column '{column}' does not exist in the data."
685
-
686
- filtered_df = df[df[column].astype(str).str.contains(pattern, case=False, na=False)]
687
-
688
- if filtered_df.empty:
689
- return f"No records found with words {words} in column '{column}'."
690
-
691
- # Save the filtered data to a new CSV
692
- timestamp = int(time.time())
693
- filtered_csv = os.path.join(storage_location, f"filtered_data_{timestamp}.csv")
694
- filtered_df.to_csv(filtered_csv, index=False)
695
- logging.info(f"Data filtered on column '{column}' for words {words}.")
696
- return f"Data filtered and saved to {filtered_csv}."
697
- except mysql.connector.Error as err:
698
- logging.error(f"Error fetching data from database: {err}")
699
- # Fallback to CSV
700
- else:
701
- logging.info("No database connection. Filtering data from CSV.")
702
-
703
- # Fallback to CSV
704
- csv_files = [f for f in os.listdir(storage_location) if f.endswith("_changes.csv") or f.endswith("_filtered.csv") or f.endswith("_sorted_asc.csv") or f.endswith("_sorted_desc.csv")]
705
- if not csv_files:
706
- return "No CSV files found to filter."
707
-
708
- # Assume the latest CSV is the target
709
- latest_csv = max([os.path.join(storage_location, f) for f in csv_files], key=os.path.getmtime)
710
- df = pd.read_csv(latest_csv)
711
-
712
- if column not in df.columns:
713
- return f"Column '{column}' does not exist in the data."
714
-
715
- filtered_df = df[df[column].astype(str).str.contains('|'.join(words), case=False, na=False)]
716
-
717
- if filtered_df.empty:
718
- return f"No records found with words {words} in column '{column}'."
719
-
720
- # Save the filtered data to a new CSV
721
- timestamp = int(time.time())
722
- filtered_csv = latest_csv.replace(".csv", f"_filtered_{timestamp}.csv")
723
- filtered_df.to_csv(filtered_csv, index=False)
724
- logging.info(f"Data filtered on column '{column}' for words {words}.")
725
- return f"Data filtered and saved to {filtered_csv}."
726
- except Exception as e:
727
- logging.error(f"Error filtering data: {e}")
728
- return f"Error filtering data: {e}"
729
-
730
- def sort_data(column: str, order: str) -> str:
731
- """
732
- Sorts the scraped data based on the specified column and order.
733
- Saves the sorted data to a new CSV file.
734
- """
735
- try:
736
- storage_location = DEFAULT_FILE_PATH
737
-
738
- connection = get_db_connection()
739
- if connection:
740
- try:
741
- cursor = connection.cursor(dictionary=True)
742
- # Fetch all data
743
- query = "SELECT * FROM scraped_data"
744
- cursor.execute(query)
745
- results = cursor.fetchall()
746
-
747
- if not results:
748
- return "No data available to sort."
749
-
750
- df = pd.DataFrame(results)
751
- if column not in df.columns:
752
- return f"Column '{column}' does not exist in the data."
753
-
754
- ascending = True if order.lower() == "ascending" else False
755
- sorted_df = df.sort_values(by=column, ascending=ascending)
756
-
757
- # Save the sorted data to a new CSV
758
- timestamp = int(time.time())
759
- sorted_csv = os.path.join(storage_location, f"sorted_data_{column}_{order.lower()}_{timestamp}.csv")
760
- sorted_df.to_csv(sorted_csv, index=False)
761
- logging.info(f"Data sorted on column '{column}' in {order} order.")
762
- return f"Data sorted and saved to {sorted_csv}."
763
- except mysql.connector.Error as err:
764
- logging.error(f"Error fetching data from database: {err}")
765
- # Fallback to CSV
766
- else:
767
- logging.info("No database connection. Sorting data from CSV.")
768
-
769
- # Fallback to CSV
770
- csv_files = [f for f in os.listdir(storage_location) if f.endswith("_changes.csv") or f.endswith("_filtered.csv") or f.endswith("_sorted_asc.csv") or f.endswith("_sorted_desc.csv")]
771
- if not csv_files:
772
- return "No CSV files found to sort."
773
-
774
- # Assume the latest CSV is the target
775
- latest_csv = max([os.path.join(storage_location, f) for f in csv_files], key=os.path.getmtime)
776
- df = pd.read_csv(latest_csv)
777
-
778
- if column not in df.columns:
779
- return f"Column '{column}' does not exist in the data."
780
-
781
- ascending = True if order.lower() == "ascending" else False
782
- sorted_df = df.sort_values(by=column, ascending=ascending)
783
-
784
- # Save the sorted data to a new CSV
785
- timestamp = int(time.time())
786
- sorted_csv = latest_csv.replace(".csv", f"_sorted_{order.lower()}_{timestamp}.csv")
787
- sorted_df.to_csv(sorted_csv, index=False)
788
- logging.info(f"Data sorted on column '{column}' in {order} order.")
789
- return f"Data sorted and saved to {sorted_csv}."
790
- except Exception as e:
791
- logging.error(f"Error sorting data: {e}")
792
- return f"Error sorting data: {e}"
793
-
794
- def export_csv(filename: str) -> str:
795
- """
796
- Exports the latest scraped data to a specified CSV filename.
797
- """
798
- try:
799
- storage_location = DEFAULT_FILE_PATH
800
-
801
- connection = get_db_connection()
802
- if connection:
803
- try:
804
- cursor = connection.cursor(dictionary=True)
805
- # Fetch all data
806
- query = "SELECT * FROM scraped_data"
807
- cursor.execute(query)
808
- results = cursor.fetchall()
809
-
810
- if not results:
811
- return "No data available to export."
812
-
813
- df = pd.DataFrame(results)
814
- export_path = os.path.join(storage_location, filename)
815
- df.to_csv(export_path, index=False)
816
- logging.info(f"Data exported to {export_path}.")
817
- return f"Data exported to {export_path}."
818
- except mysql.connector.Error as err:
819
- logging.error(f"Error exporting data from database: {err}")
820
- # Fallback to CSV
821
- else:
822
- logging.info("No database connection. Exporting data from CSV.")
823
-
824
- # Fallback to CSV
825
- csv_files = [f for f in os.listdir(storage_location) if f.endswith("_changes.csv") or f.endswith("_filtered.csv") or f.endswith("_sorted_asc.csv") or f.endswith("_sorted_desc.csv")]
826
- if not csv_files:
827
- return "No CSV files found to export."
828
-
829
- # Assume the latest CSV is the target
830
- latest_csv = max([os.path.join(storage_location, f) for f in csv_files], key=os.path.getmtime)
831
- df = pd.read_csv(latest_csv)
832
- export_path = os.path.join(storage_location, filename)
833
- df.to_csv(export_path, index=False)
834
- logging.info(f"Data exported to {export_path}.")
835
- return f"Data exported to {export_path}."
836
- except Exception as e:
837
- logging.error(f"Error exporting CSV: {e}")
838
- return f"Error exporting CSV: {e}"
839
-
840
- def log_action(action: str) -> str:
841
- """
842
- Logs a custom action message to the MySQL database or CSV.
843
  """
844
- try:
845
- connection = get_db_connection()
846
- if connection:
847
- try:
848
- cursor = connection.cursor()
849
- insert_query = """
850
- INSERT INTO action_logs (action)
851
- VALUES (%s)
852
- """
853
- cursor.execute(insert_query, (action,))
854
- connection.commit()
855
- logging.info(f"Action logged in database: {action}")
856
- cursor.close()
857
- connection.close()
858
- return f"Action logged: {action}"
859
- except mysql.connector.Error as err:
860
- logging.error(f"Error logging action to database: {err}")
861
- # Fallback to CSV
862
- else:
863
- logging.info("No database connection. Logging action to CSV.")
864
-
865
- # Fallback to CSV
866
- storage_location = DEFAULT_FILE_PATH
867
- try:
868
- os.makedirs(storage_location, exist_ok=True)
869
- csv_file_path = os.path.join(storage_location, "action_logs.csv")
870
- file_exists = os.path.isfile(csv_file_path)
871
-
872
- with open(csv_file_path, "a", newline="", encoding="utf-8") as csvfile:
873
- fieldnames = ["timestamp", "action"]
874
- writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
875
- if not file_exists:
876
- writer.writeheader()
877
- writer.writerow(
878
- {
879
- "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
880
- "action": action,
881
- }
882
- )
883
- logging.info(f"Action logged to CSV: {action}")
884
- return f"Action logged: {action}"
885
- except Exception as e:
886
- logging.error(f"Error logging action to CSV: {e}")
887
- return f"Error logging action: {e}"
888
- except Exception as e:
889
- logging.error(f"Error logging action: {e}")
890
- return f"Error logging action: {e}"
891
-
892
- # Function to get the latest CSV file based on modification time
893
- def get_latest_csv() -> str:
894
- """
895
- Retrieves the latest CSV file from the storage directory based on modification time.
896
  """
897
  try:
898
- storage_location = "/home/users/app/scraped_data"
899
- csv_files = [f for f in os.listdir(storage_location) if f.endswith(".csv")]
900
- if not csv_files:
901
- return None
902
-
903
- latest_csv = max([os.path.join(storage_location, f) for f in csv_files], key=os.path.getmtime)
904
- return latest_csv
905
- except Exception as e:
906
- logging.error(f"Error retrieving latest CSV: {e}")
907
- return None
908
-
909
- def respond(
910
- message: str,
911
- history: list,
912
- system_message: str,
913
- max_tokens: int,
914
- temperature: float,
915
- top_p: float,
916
- ) -> str:
917
- """
918
- Generates a response using OpenLlamaForCausalLM.
919
- """
920
- try:
921
- # Check if the message contains a command
922
- command, params = parse_command(message)
923
- if command:
924
- # Execute the corresponding function
925
- response = execute_command(command, params)
926
- else:
927
- # Generate a regular response using OpenLlama
928
- prompt = (
929
- f"System: {system_message}\n"
930
- f"History: {history}\n"
931
- f"User: {message}\n"
932
- f"Assistant:"
933
- )
934
- response = openllama_pipeline(
935
- prompt,
936
- max_length=max_tokens,
937
- temperature=temperature,
938
- top_p=top_p,
939
- )[0]["generated_text"]
940
-
941
-
942
- # Extract the assistant's reply
943
- response = response.split("Assistant:")[-1].strip()
944
- return response
945
- except Exception as e:
946
- logging.error(f"Error generating response: {e}")
947
- return "Error generating response."
948
-
949
- # Define the Gradio interface
950
- def create_interface() -> gr.Blocks():
951
- """
952
- Defines and returns the Gradio interface for the application.
953
- """
954
- with gr.Blocks() as demo:
955
- gr.Markdown("# All-in-One Scraper, Database, and RSS Feeder")
956
-
957
- with gr.Row():
958
- with gr.Column():
959
- # Scraping Controls
960
- storage_location = gr.Textbox(
961
- value=DEFAULT_FILE_PATH, label="Storage Location"
962
- )
963
- urls = gr.Textbox(
964
- label="URLs (comma separated)",
965
- placeholder="https://example.com, https://anotherexample.com",
966
- )
967
- scrape_interval = gr.Slider(
968
- minimum=1,
969
- maximum=60,
970
- value=5,
971
- step=1,
972
- label="Scrape Interval (minutes)",
973
- )
974
- content_type = gr.Radio(
975
- choices=["text", "media", "both"],
976
- value="text",
977
- label="Content Type",
978
- )
979
- selector = gr.Textbox(
980
- label="CSS Selector for Media (Optional)",
981
- placeholder="e.g., img.main-image",
982
- )
983
- start_button = gr.Button("Start Scraping")
984
- stop_button = gr.Button("Stop Scraping")
985
- status_output = gr.Textbox(
986
- label="Status Output", interactive=False, lines=2
987
- )
988
-
989
- with gr.Column():
990
- # Chat Interface
991
- chat_history = gr.Chatbot(label="Chat History")
992
- with gr.Row():
993
- message = gr.Textbox(label="Message", placeholder="Type your message here...")
994
- system_message = gr.Textbox(
995
- value="You are a helpful assistant.", label="System message"
996
- )
997
- max_tokens = gr.Slider(
998
- minimum=1,
999
- maximum=2048,
1000
- value=512,
1001
- step=1,
1002
- label="Max new tokens",
1003
- )
1004
- temperature = gr.Slider(
1005
- minimum=0.1,
1006
- maximum=4.0,
1007
- value=0.7,
1008
- step=0.1,
1009
- label="Temperature",
1010
- )
1011
- top_p = gr.Slider(
1012
- minimum=0.1,
1013
- maximum=1.0,
1014
- value=0.95,
1015
- step=0.05,
1016
- label="Top-p (nucleus sampling)",
1017
- )
1018
- response_box = gr.Textbox(label="Response", interactive=False, lines=2)
1019
-
1020
- with gr.Row():
1021
- with gr.Column():
1022
- # CSV Display Controls
1023
- selected_url_csv = gr.Textbox(
1024
- label="Select URL for CSV Content",
1025
- placeholder="https://example.com",
1026
- )
1027
- csv_button = gr.Button("Display CSV Content")
1028
- csv_content_output = gr.Textbox(
1029
- label="CSV Content Output", interactive=False, lines=10
1030
- )
1031
-
1032
- with gr.Column():
1033
- # RSS Feed Generation Controls
1034
- selected_url_rss = gr.Textbox(
1035
- label="Select URL for RSS Feed",
1036
- placeholder="https://example.com",
1037
- )
1038
- rss_button = gr.Button("Generate RSS Feed")
1039
- rss_output = gr.Textbox(
1040
- label="RSS Feed Output", interactive=False, lines=20
1041
- )
1042
-
1043
- # Historical Data View
1044
- with gr.Row():
1045
- historical_view_url = gr.Textbox(
1046
- label="Select URL for Historical Data",
1047
- placeholder="https://example.com",
1048
- )
1049
- historical_button = gr.Button("View Historical Data")
1050
- historical_output = gr.Dataframe(
1051
- headers=["ID", "URL", "Content Hash", "Change Detected"],
1052
- label="Historical Data",
1053
- interactive=False
1054
- )
1055
-
1056
-
1057
-
1058
- # Connect buttons to their respective functions
1059
- start_button.click(
1060
- fn=start_scraping,
1061
- inputs=[
1062
- storage_location,
1063
- urls,
1064
- scrape_interval,
1065
- content_type,
1066
- selector,
1067
-
1068
- ],
1069
- outputs=status_output,
1070
- )
1071
-
1072
- stop_button.click(fn=stop_scraping, outputs=status_output)
1073
-
1074
- csv_button.click(
1075
- fn=display_csv,
1076
- inputs=[storage_location, selected_url_csv],
1077
- outputs=csv_content_output,
1078
- )
1079
-
1080
- rss_button.click(
1081
- fn=generate_rss_feed,
1082
- inputs=[storage_location, selected_url_rss],
1083
- outputs=rss_output,
1084
- )
1085
-
1086
- historical_button.click(
1087
- fn=display_historical_data,
1088
- inputs=[storage_location, historical_view_url],
1089
- outputs=historical_output,
1090
- )
1091
-
1092
- # Connect message submission to the chat interface
1093
- def update_chat(message_input, history, system_msg, max_toks, temp, top_p_val):
1094
- if not message_input.strip():
1095
- return history, "Please enter a message."
1096
-
1097
- response = respond(
1098
- message_input,
1099
- history,
1100
- system_msg,
1101
- max_toks,
1102
- temp,
1103
- top_p_val,
1104
- )
1105
- history.append((message_input, response))
1106
- return history, response
1107
-
1108
- message.submit(
1109
- update_chat,
1110
- inputs=[
1111
- message,
1112
- chat_history,
1113
- system_message,
1114
- max_tokens,
1115
- temperature,
1116
- top_p,
1117
- ],
1118
- outputs=[chat_history, response_box],
1119
- )
1120
-
1121
- return demo
1122
-
1123
- # Function to display historical data
1124
- def display_historical_data(storage_location: str, url: str):
1125
- """
1126
- Retrieves and displays historical scraping data for a given URL.
1127
- """
1128
- try:
1129
- connection = get_db_connection()
1130
- if connection:
1131
- try:
1132
- cursor = connection.cursor(dictionary=True)
1133
- query = "SELECT * FROM scraped_data WHERE url = %s ORDER BY change_detected DESC"
1134
- cursor.execute(query, (url,))
1135
- results = cursor.fetchall()
1136
-
1137
- if not results:
1138
- return pd.DataFrame()
1139
-
1140
- df = pd.DataFrame(results)
1141
- cursor.close()
1142
- connection.close()
1143
- return df
1144
- except mysql.connector.Error as err:
1145
- logging.error(f"Error fetching historical data from database: {err}")
1146
- # Fallback to CSV
1147
- else:
1148
- logging.info("No database connection. Fetching historical data from CSV.")
1149
-
1150
- # Fallback to CSV
1151
- hostname = urlparse(url).hostname
1152
- csv_path = os.path.join(storage_location, f"{hostname}_changes.csv")
1153
- if os.path.exists(csv_path):
1154
- df = pd.read_csv(csv_path)
1155
- return df
1156
- else:
1157
- return pd.DataFrame()
1158
- except Exception as e:
1159
- logging.error(f"Error fetching historical data for {url}: {e}")
1160
- return pd.DataFrame()
1161
-
1162
- def load_model():
1163
- """
1164
- Loads the openLlama model and tokenizer once and returns the pipeline.
1165
- """
1166
- try:
1167
- model_name = "openlm-research/open_llama_3b_v2"
1168
  tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, legacy=False)
1169
  model = AutoModelForCausalLM.from_pretrained(model_name)
1170
-
1171
  # This should be inside the try block
1172
- max_supported_length = 2048
1173
-
1174
  openllama_pipeline = pipeline(
1175
  "text-generation",
1176
  model=model,
@@ -1181,74 +34,10 @@ def display_historical_data(storage_location: str, url: str):
1181
  top_p=0.95,
1182
  device=0 if torch.cuda.is_available() else -1,
1183
  )
1184
- logging.info("Model loaded successfully.")
1185
- return openllama_pipeline # Return the pipeline
1186
  except Exception as e:
1187
- logging.error(f"Error loading google/flan-t5-xl model: {e}")
1188
  return None
1189
 
1190
- # Load the model once at the start
1191
- chat_pipeline = load_model()
1192
-
1193
- # Automated Testing using unittest
1194
- class TestApp(unittest.TestCase):
1195
- def test_parse_command_filter(self):
1196
- command = "Filter apples, oranges in column Description"
1197
- parsed_command = parse_command(command)
1198
- self.assertEqual(parsed_command[0], "filter")
1199
- self.assertListEqual(parsed_command[1]["words"], ["apples", "oranges"])
1200
- self.assertEqual(parsed_command[1]["column"], "Description")
1201
-
1202
- def test_parse_command_sort(self):
1203
- command = "Sort Price ascending"
1204
- parsed_command = parse_command(command)
1205
- self.assertEqual(parsed_command[0], "sort")
1206
- self.assertEqual(parsed_command[1]["column"], "Price")
1207
- self.assertEqual(parsed_command[1]["order"], "ascending")
1208
-
1209
- def test_parse_command_export(self):
1210
- command = "Export to CSV as filtered_data.csv"
1211
- parsed_command = parse_command(command)
1212
- self.assertEqual(parsed_command[0], "export")
1213
- self.assertEqual(parsed_command[1]["filename"], "filtered_data.csv")
1214
-
1215
- def test_parse_command_log(self):
1216
- command = "Log action Filtered data for specific fruits"
1217
- parsed_command = parse_command(command)
1218
- self.assertEqual(parsed_command[0], "log")
1219
- self.assertEqual(parsed_command[1]["action"], "Filtered data for specific fruits")
1220
-
1221
- def test_database_connection(self):
1222
- connection = get_db_connection()
1223
- # Connection may be None if not configured; adjust the test accordingly
1224
- if connection:
1225
- self.assertTrue(connection.is_connected())
1226
- connection.close()
1227
- else:
1228
- self.assertIsNone(connection)
1229
-
1230
- def main():
1231
- # Initialize and run the application
1232
- logging.info("Starting the application...")
1233
- model = load_model()
1234
- if model:
1235
- logging.info("Application started successfully.")
1236
- print("Main function executed")
1237
- print("Creating interface...")
1238
- demo = create_interface()
1239
- print("Launching interface...")
1240
- demo.launch(server_name="0.0.0.0", server_port=7860)
1241
- else:
1242
- logging.error("Failed to start the application.")
1243
-
1244
- # Main execution
1245
- if __name__ == "__main__":
1246
- # Initialize database
1247
- initialize_database()
1248
-
1249
- # Create and launch Gradio interface
1250
- demo = create_interface()
1251
- demo.launch()
1252
-
1253
- # Run automated tests
1254
- unittest.main(argv=[''], exit=False)
 
1
+ # ... (your existing imports and code before model loading) ...
2

3
+ # Dictionary to store model loading functions
4
+ model_loaders = {
5
+ "Falcon": lambda: load_model("tiiuae/falcon-7b"),
6
+ "Flan-T5": lambda: load_model("google/flan-t5-xl"),
7
+ # Add more models and their loading functions here
8
+ }
9

10
+ model_option = st.selectbox("Select a Model", list(model_loaders.keys()))
11

12
+ # Load the selected model
13
+ model = model_loaders[model_option]()
14

15
+ # ... (rest of your existing code) ...
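As committed in this hunk, the top-level call model = model_loaders[model_option]() runs before def load_model is defined further down in app.py, so unless an earlier definition is hidden behind the "# ..." placeholder the app would raise a NameError on startup; Streamlit also reruns the whole script on every interaction, which would reload the selected model each time. A minimal sketch of the same swapping pattern with the loader defined first and cached; the st.cache_resource decorator, the pipeline arguments, and the omission of Flan-T5 (a seq2seq model that would need AutoModelForSeq2SeqLM and the "text2text-generation" task) are assumptions of the sketch, not part of this commit:

import logging

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline


@st.cache_resource  # assumption: cache loaded pipelines so Streamlit reruns do not reload the weights
def load_model(model_name: str):
    """Return a text-generation pipeline for model_name, or None if loading fails."""
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
        model = AutoModelForCausalLM.from_pretrained(model_name)
        return pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            device=0 if torch.cuda.is_available() else -1,
        )
    except Exception as exc:
        logging.error(f"Error loading {model_name}: {exc}")
        return None


# The loader is defined before it is referenced, so selecting a model cannot raise a NameError.
# Flan-T5 from the commit is omitted here: it is a seq2seq model and would need
# AutoModelForSeq2SeqLM with the "text2text-generation" task instead of AutoModelForCausalLM.
model_loaders = {
    "Falcon": lambda: load_model("tiiuae/falcon-7b"),
}

model_option = st.selectbox("Select a Model", list(model_loaders.keys()))
model = model_loaders[model_option]()

Because st.cache_resource keys the cache on the function arguments, switching between selectbox entries only pays the download and load cost once per model name.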
16

17

18
+ def load_model(model_name: str):
19
  """
20
+ Loads the specified model and tokenizer.
21
  """
22
  try:
23
  tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, legacy=False)
24
  model = AutoModelForCausalLM.from_pretrained(model_name)
25
  # This should be inside the try block
26
+ max_supported_length = 2048 # Get this from the model config
27
  openllama_pipeline = pipeline(
28
  "text-generation",
29
  model=model,

34
  top_p=0.95,
35
  device=0 if torch.cuda.is_available() else -1,
36
  )
37
+ logging.info(f"{model_name} loaded successfully.")
38
+ return openllama_pipeline
39
  except Exception as e:
40
+ logging.error(f"Error loading {model_name} model: {e}")
41
  return None
42

43
+ # ... (rest of your existing code) ...
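If loading succeeds, the returned pipeline is directly callable; the prompt and generation parameters below are illustrative values, not taken from the elided lines of this hunk:

generator = load_model("tiiuae/falcon-7b")
if generator is not None:
    output = generator(
        "Hello, world",
        max_new_tokens=64,
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
    )
    print(output[0]["generated_text"])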