# ... (your existing imports and code before model loading) ...

import logging

import streamlit as st
import torch
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    pipeline,
)


def load_model(model_name: str):
    """
    Load the specified model and tokenizer and wrap them in a
    generation pipeline. Returns None if loading fails.
    """
    try:
        # Use the default (fast) tokenizer; some models, e.g. Falcon,
        # ship only a fast tokenizer, so forcing use_fast=False would fail.
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Seq2seq models such as Flan-T5 cannot be loaded as causal LMs,
        # so dispatch on the architecture reported by the config.
        config = AutoConfig.from_pretrained(model_name)
        if config.is_encoder_decoder:
            model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
            task = "text2text-generation"
        else:
            model = AutoModelForCausalLM.from_pretrained(model_name)
            task = "text-generation"
        # Read the context window from the model config when available,
        # falling back to 2048 otherwise.
        max_supported_length = getattr(config, "max_position_embeddings", 2048)
        text_gen_pipeline = pipeline(
            task,
            model=model,
            tokenizer=tokenizer,
            truncation=True,
            max_length=max_supported_length,
            temperature=0.7,
            top_p=0.95,
            device=0 if torch.cuda.is_available() else -1,
        )
        logging.info(f"{model_name} loaded successfully.")
        return text_gen_pipeline
    except Exception as e:
        logging.error(f"Error loading {model_name} model: {e}")
        return None


# Dictionary mapping display names to model-loading functions.
# These lambdas call load_model, so it must be defined above this block.
model_loaders = {
    "Falcon": lambda: load_model("tiiuae/falcon-7b"),
    "Flan-T5": lambda: load_model("google/flan-t5-xl"),
    # Add more models and their loading functions here
}

model_option = st.selectbox("Select a Model", list(model_loaders.keys()))

# Load the selected model; load_model returns None on failure.
model = model_loaders[model_option]()
if model is None:
    st.error(f"Failed to load {model_option}; check the logs for details.")

# ... (rest of your existing code) ...
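
Because Streamlit reruns the entire script on every widget interaction, the snippet above reloads the selected model from scratch each time the user touches any control. A minimal sketch of one way to avoid that with Streamlit's st.cache_resource decorator follows; the wrapper name, prompt text, and st.write call are illustrative additions, not part of the original code:

# Cache the loaded pipeline so reruns reuse it instead of reloading the
# weights; st.cache_resource keys the cache on the function arguments,
# so each model name is loaded at most once per process.
@st.cache_resource
def load_model_cached(model_name: str):
    return load_model(model_name)

generator = load_model_cached("tiiuae/falcon-7b")
if generator is not None:
    # The pipeline returns a list of dicts with a "generated_text" key.
    result = generator("Explain what a tokenizer does.")
    st.write(result[0]["generated_text"])

One caveat: since load_model returns None on failure, a failed load would also be cached. Clearing the cache with st.cache_resource.clear(), or raising an exception instead of returning None, avoids retaining the failure across reruns.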