acecalisto3 committed
Commit 2db6252
1 Parent(s): debab79

Update app.py

Files changed (1)
  1. app.py +25 -34
app.py CHANGED
@@ -22,7 +22,7 @@ from selenium.common.exceptions import (
     StaleElementReferenceException,
 )
 from webdriver_manager.chrome import ChromeDriverManager
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
+from transformers import AutoTokenizer, OpenLlamaForCausalLM, pipeline
 import gradio as gr
 import xml.etree.ElementTree as ET
 import torch
@@ -31,23 +31,19 @@ from mysql.connector import errorcode, pooling
 from dotenv import load_dotenv
 from huggingface_hub import login
 
-# Load model directly
-from transformers import AutoTokenizer, AutoModelForMaskedLM
-
-tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-mpnet-base-v2", clean_up_tokenization_spaces=True)
-model = AutoModelForMaskedLM.from_pretrained("sentence-transformers/all-mpnet-base-v2")
-
-# Define classifier for zero-shot classification
-classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
-
-# Define nlp using a simple tokenizer
-from transformers import AutoTokenizer
+model_name = "openlm-research/open_llama_3b_v2"  # Or another OpenLlama variant
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = OpenLlamaForCausalLM.from_pretrained(model_name)
+
+openllama_pipeline = pipeline(
+    "text-generation",
+    model=model,
+    tokenizer=tokenizer,
+    device=0 if torch.cuda.is_available() else -1  # Use GPU if available
+)
+
 nlp = AutoTokenizer.from_pretrained("bert-base-uncased")
 
-
-
-
-
 HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
 if not HUGGINGFACE_TOKEN:
     raise ValueError("HUGGINGFACE_TOKEN is not set in the environment variables.")
@@ -907,7 +903,6 @@ def get_latest_csv() -> str:
         logging.error(f"Error retrieving latest CSV: {e}")
         return None
 
-# Chat Response Function with Dynamic Command Handling
 def respond(
     message: str,
     history: list,
@@ -917,12 +912,8 @@ def respond(
     top_p: float,
 ) -> str:
     """
-    Generates a response using the google/flan-t5-xl model based on the user's message and history.
-    Additionally, handles dynamic commands to interact with individual components.
+    Generates a response using OpenLlamaForCausalLM.
     """
-    if chat_pipeline is None:
-        return "Error: Chat model is not loaded."
-
     try:
         # Check if the message contains a command
         command, params = parse_command(message)
@@ -930,20 +921,20 @@ def respond(
             # Execute the corresponding function
             response = execute_command(command, params)
         else:
-            # Generate a regular response using the model
-            prompt = (
-                f"System: {system_message}\n"
-                f"History: {history}\n"
-                f"User: {message}\n"
-                f"Assistant:"
-            )
-            response = chat_pipeline(
-                prompt,
-                max_length=max_tokens,
-                temperature=temperature,
-                top_p=top_p,
-                num_return_sequences=1,
-            )[0]["generated_text"]
+            # Generate a regular response using OpenLlama
+            prompt = (
+                f"System: {system_message}\n"
+                f"History: {history}\n"
+                f"User: {message}\n"
+                f"Assistant:"
+            )
+            response = openllama_pipeline(
+                prompt,
+                max_length=max_tokens,
+                temperature=temperature,
+                top_p=top_p,
+            )[0]["generated_text"]
+
 
         # Extract the assistant's reply
         response = response.split("Assistant:")[-1].strip()
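For reference, a minimal standalone sketch of the model-loading path this commit introduces. It assumes the openlm-research/open_llama_3b_v2 checkpoint named in the diff and swaps in AutoModelForCausalLM, which resolves the concrete architecture from the checkpoint config; the commit itself calls OpenLlamaForCausalLM directly, so treat this only as an illustration of the intended text-generation setup, not the app's exact code.

    # Sketch only: load an OpenLLaMA checkpoint and build a text-generation pipeline.
    # AutoModelForCausalLM is used here as a generic loader; the commit uses
    # OpenLlamaForCausalLM instead.
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

    model_name = "openlm-research/open_llama_3b_v2"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device=0 if torch.cuda.is_available() else -1,  # GPU if available, else CPU
    )

    print(generator("User: Hello\nAssistant:", max_length=64)[0]["generated_text"])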
 
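The rewritten respond() keeps the command-dispatch branch and otherwise formats a System/History/User prompt for openllama_pipeline, then keeps only the text after the final "Assistant:" marker. Below is a hedged sketch of how a handler with this signature is typically wired into the Gradio UI the file already imports; the import path, control labels, and default values are illustrative assumptions, not taken from app.py.

    # Sketch only: hook respond(message, history, system_message, max_tokens,
    # temperature, top_p) into a Gradio chat UI. Defaults are assumptions.
    import gradio as gr
    from app import respond  # assumption: respond() is importable from app.py

    demo = gr.ChatInterface(
        fn=respond,
        additional_inputs=[
            gr.Textbox(value="You are a helpful assistant.", label="System message"),
            gr.Slider(32, 1024, value=256, step=32, label="Max tokens"),
            gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
            gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p"),
        ],
    )

    if __name__ == "__main__":
        demo.launch()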