acecalisto3 committed on
Commit d507c63
1 Parent(s): 4ca234f

Update app.py

Files changed (1)
  1. app.py +10 -32
app.py CHANGED
@@ -31,41 +31,19 @@ import xml.etree.ElementTree as ET
 import torch
 import mysql.connector
 from mysql.connector import errorcode, pooling
-from dotenv import load_dotenv
-from huggingface_hub import login
-from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
-
+import nltk
+import importlib
 
-def load_model(): # Define load_model() first
-    """
-    Loads the openLlama model and tokenizer once and returns the pipeline.
-    """
-    try:
-        model_name = "openlm-research/open_llama_3b_v2"
-        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, legacy=False)
-        model = AutoModelForCausalLM.from_pretrained(model_name)
-
-        max_supported_length = 2048
-
-        openllama_pipeline = pipeline(
-            "text-generation",
-            model=model,
-            tokenizer=tokenizer,
-            truncation=True,
-            max_length=max_supported_length,
-            temperature=0.7,
-            top_p=0.95,
-            device=0 if torch.cuda.is_available() else -1,
-        )
-        logging.info("Model loaded successfully.")
-        return openllama_pipeline
-    except Exception as e:
-        logging.error(f"Error loading google/flan-t5-xl model: {e}")
-        return None
+st.title("CEEMEESEEK with Model Selection")
 
-chat_pipeline = load_model() # Now call load_model()
+model_option = st.selectbox("Select a Model", ["Falcon", "Flan-T5", "Other Model"]) # Add your model names
 
-nlp = AutoTokenizer.from_pretrained("bert-base-uncased")
+if model_option == "Falcon":
+    model_module = importlib.import_module("model_falcon") # Assuming you create model_falcon.py
+    model = model_module.load_falcon_model()
+elif model_option == "Flan-T5":
+    model_module = importlib.import_module("model_flan_t5")
+    model = model_module.load_flan_t5_model()
 
 HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
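
The new branches assume companion modules (model_falcon.py, model_flan_t5.py) that are not part of this commit. Below is a minimal sketch of what model_falcon.py could look like, following the shape of the load_model() helper removed above; only the module name and the load_falcon_model() function name come from app.py, while the checkpoint "tiiuae/falcon-7b-instruct" and the pipeline settings are assumptions:

# model_falcon.py -- hypothetical companion module; adjust the checkpoint
# and generation settings to whatever Falcon variant the app actually uses.
import logging

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline


def load_falcon_model():
    """Load a Falcon checkpoint and return a text-generation pipeline, or None on failure."""
    try:
        model_name = "tiiuae/falcon-7b-instruct"  # assumed checkpoint
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)

        # app.py stores this return value in `model` after the selectbox choice.
        falcon_pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            device=0 if torch.cuda.is_available() else -1,
        )
        logging.info("Falcon model loaded successfully.")
        return falcon_pipeline
    except Exception as e:
        logging.error(f"Error loading Falcon model: {e}")
        return None

Routing through importlib.import_module keeps each backend optional: a model's weights are only downloaded and held in memory after the user picks it in the Streamlit selectbox, rather than at import time as the removed load_model() call did.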