prnvtripathi14 commited on
Commit
a40ec6a
·
verified ·
1 Parent(s): 026025c
Files changed (1) hide show
  1. app.py +24 -31
app.py CHANGED
@@ -8,41 +8,34 @@ logging.basicConfig(level=logging.DEBUG,
8
  format='%(asctime)s - %(levelname)s - %(message)s')
9
  logger = logging.getLogger(__name__)
10
 
11
# Candidate models, ordered from most to least capable; load_model() walks
# this list and keeps the first one that loads successfully.
MODELS_TO_TRY = [
    "google/flan-t5-xxl",    # powerful instruction-following model
    "bigscience/T0pp",       # optimized for zero-shot tasks
    "t5-large",              # general-purpose text generation
    "google/flan-t5-large",  # lightweight instruction-tuned model
]
18
 
19
def load_model():
    """Try each candidate in MODELS_TO_TRY; return the first pipeline that loads.

    Returns:
        A Hugging Face ``text2text-generation`` pipeline built from the first
        model in ``MODELS_TO_TRY`` that loads successfully, or ``None`` if
        every candidate fails (callers must check for ``None``).
    """
    for model_name in MODELS_TO_TRY:
        try:
            # Lazy %-style args keep formatting cost off the happy path.
            logger.info("Attempting to load model: %s", model_name)

            # Load model and tokenizer for this candidate.
            model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
            tokenizer = AutoTokenizer.from_pretrained(model_name)

            # Create the text generation pipeline.
            generator = pipeline(
                "text2text-generation",
                model=model,
                tokenizer=tokenizer,
                max_length=512,
                num_return_sequences=1,
            )
            logger.info("Successfully loaded model: %s", model_name)
            return generator
        except Exception:
            # FIX: logger.exception preserves the traceback, which the
            # original logger.error(f"...: {e}") discarded — load failures
            # (OOM, missing weights, auth) are much easier to diagnose.
            logger.exception("Failed to load model %s", model_name)

    # Deliberate best-effort fallback: surface the failure via logs and None
    # rather than crashing at import time.
    logger.error("All model attempts failed. No model loaded.")
    return None
46
 
47
# Build the shared generator once at import time so request handlers reuse it.
generator = load_model()
 
8
  format='%(asctime)s - %(levelname)s - %(message)s')
9
  logger = logging.getLogger(__name__)
10
 
11
# Hugging Face model id loaded by load_model(); a single instruction-tuned
# model replaces the earlier multi-candidate fallback list.
MODEL_NAME = "google/flan-t5-large"
 
 
 
 
 
13
 
14
def load_model():
    """Load MODEL_NAME and return a PyTorch text2text-generation pipeline.

    Returns:
        A Hugging Face pipeline ready for inference, or ``None`` when the
        model or tokenizer fails to load (callers must check for ``None``).
    """
    try:
        logger.info("Loading model: %s with PyTorch backend", MODEL_NAME)

        # FIX: ``framework`` is a pipeline() argument, not a from_pretrained()
        # argument — the original passed framework="pt" here, where it is not
        # an accepted parameter and gets forwarded as an unknown kwarg.
        model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

        # Create the text generation pipeline.
        generator = pipeline(
            "text2text-generation",
            model=model,
            tokenizer=tokenizer,
            framework="pt",  # force the PyTorch backend even if TF weights exist
            max_length=512,
            num_return_sequences=1,
        )
        logger.info("Successfully loaded model: %s", MODEL_NAME)
        return generator
    except Exception:
        # logger.exception keeps the traceback (logger.error dropped it),
        # making load failures diagnosable from the logs.
        logger.exception("Failed to load model %s", MODEL_NAME)
        return None
 
 
39
 
40
# Eagerly initialize the text-generation pipeline when the module is imported.
generator = load_model()