# ... (your existing imports and code before model loading) ...
# Imports required by the code shown below:
import logging

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# load_model must be defined before the model_loaders lambdas below are invoked,
# so the function now precedes the dictionary.
def load_model(model_name: str):
    """
    Load the specified model and tokenizer and wrap them in a
    text-generation pipeline. Returns None if loading fails.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, legacy=False)
        model = AutoModelForCausalLM.from_pretrained(model_name)
        # Read the context window from the model config, falling back to 2048.
        max_supported_length = getattr(model.config, "max_position_embeddings", 2048)
        text_gen_pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            truncation=True,
            max_length=max_supported_length,
            do_sample=True,  # sampling must be enabled for temperature/top_p to apply
            temperature=0.7,
            top_p=0.95,
            device=0 if torch.cuda.is_available() else -1,
        )
        logging.info(f"{model_name} loaded successfully.")
        return text_gen_pipeline
    except Exception as e:
        logging.error(f"Error loading {model_name} model: {e}")
        return None
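
# The dictionary below also registers Flan-T5, which is an encoder-decoder
# (seq2seq) model and cannot be loaded with AutoModelForCausalLM. A minimal
# task-aware sketch of one way to handle both kinds of model; the name
# load_model_any is illustrative, not part of the original app:
from transformers import AutoConfig, AutoModelForSeq2SeqLM

def load_model_any(model_name: str):
    try:
        config = AutoConfig.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        if config.is_encoder_decoder:
            # Seq2seq models (T5, BART, ...) need the text2text-generation task.
            model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
            task = "text2text-generation"
        else:
            model = AutoModelForCausalLM.from_pretrained(model_name)
            task = "text-generation"
        return pipeline(
            task,
            model=model,
            tokenizer=tokenizer,
            device=0 if torch.cuda.is_available() else -1,
        )
    except Exception as e:
        logging.error(f"Error loading {model_name} model: {e}")
        return None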
# Dictionary mapping display names to model-loading functions.
model_loaders = {
    "Falcon": lambda: load_model("tiiuae/falcon-7b"),
    # Note: Flan-T5 is encoder-decoder, so load_model's AutoModelForCausalLM
    # path will not load it; see the task-aware sketch above.
    "Flan-T5": lambda: load_model("google/flan-t5-xl"),
    # Add more models and their loading functions here
}

model_option = st.selectbox("Select a Model", list(model_loaders.keys()))

# Load the selected model; load_model returns None on failure.
model = model_loaders[model_option]()
if model is None:
    st.error(f"Failed to load {model_option}. Check the logs for details.")
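
# Optional: Streamlit reruns this whole script on every widget interaction,
# so an uncached loader would reload the checkpoint on each rerun. A minimal
# caching sketch, assuming Streamlit >= 1.18 (which provides
# st.cache_resource); the wrapper name load_model_cached is illustrative:
@st.cache_resource
def load_model_cached(model_name: str):
    # Streamlit keeps one pipeline object per distinct model_name.
    return load_model(model_name)

# Example of calling the returned pipeline (hypothetical prompt; the app's
# real UI code is elided below):
# if model is not None:
#     outputs = model("Once upon a time", max_new_tokens=50)
#     st.write(outputs[0]["generated_text"])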
# ... (rest of your existing code) ...