YALCINKAYA committed on
Commit 4721a1c · 1 Parent(s): f4c3c98

bug fix for methods

Files changed (1):
  app.py  +19 -18
app.py CHANGED
@@ -18,22 +18,20 @@ tokenizer = None
 
 def get_model_and_tokenizer(model_id):
     global model, tokenizer
-    try:
-        print(f"Loading tokenizer for model_id: {model_id}")
-        # Load the tokenizer
-        tokenizer = AutoTokenizer.from_pretrained(model_id)
-        tokenizer.pad_token = tokenizer.eos_token
-
-        print(f"Loading model and for model_id: {model_id}")
-        # Load the model
-        model = AutoModelForCausalLM.from_pretrained(model_id) #, device_map="auto")
-        model.config.use_cache = False
-
-    except Exception as e:
-        print(f"Error loading model: {e}")
-
-        return "No complete blocks found. Please check the format of the response."
-
+    if model is None or tokenizer is None:
+        try:
+            print(f"Loading tokenizer for model_id: {model_id}")
+            tokenizer = AutoTokenizer.from_pretrained(model_id)
+            tokenizer.pad_token = tokenizer.eos_token
+
+            print(f"Loading model for model_id: {model_id}")
+            model = AutoModelForCausalLM.from_pretrained(model_id)
+            model.config.use_cache = False
+        except Exception as e:
+            print(f"Error loading model: {e}")
+            raise e # Raise the error to be caught in the POST request
+    else:
+        print(f"Model and tokenizer for {model_id} are already loaded.")
 
 # max_new_tokens=100,
 # min_length=5,
@@ -70,7 +68,10 @@ def get_model_and_tokenizer(model_id):
 #truncation=True, # Enable truncation for longer prompts
 #
 
-def generate_response(user_input):
+def generate_response(user_input, model_id):
+    # Ensure model and tokenizer are loaded
+    get_model_and_tokenizer(model_id) # Load the model/tokenizer if not already loaded
+
     prompt = formatted_prompt(user_input)
     inputs = tokenizer([prompt], return_tensors="pt")
 
@@ -87,7 +88,7 @@ def generate_response(user_input):
 
     outputs = model.generate(**inputs, generation_config=generation_config)
     response = tokenizer.decode(outputs[:, inputs['input_ids'].shape[-1]:][0], skip_special_tokens=True)
-    return response.strip().split("Assistant:")[-1].strip() # Get the part after 'Assistant:'
+    return response.strip().split("Assistant:")[-1].strip()
 
 def formatted_prompt(question) -> str:
     return f"<|startoftext|>User: {question}\nAssistant:"
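Note: the re-raised loading error is intended to surface in the POST request handler mentioned in the new except branch. As a rough illustration only (the Flask framework, the "/generate" route, and the "prompt"/"model_id" JSON fields are assumptions for this sketch, not part of this commit), the updated functions could be called like this:

from flask import Flask, request, jsonify
from app import generate_response  # new (user_input, model_id) signature from this commit

api = Flask(__name__)

@api.route("/generate", methods=["POST"])  # hypothetical endpoint, not in the diff
def generate():
    payload = request.get_json(force=True)
    try:
        # generate_response() now lazy-loads the model/tokenizer via
        # get_model_and_tokenizer() and re-raises any loading error.
        text = generate_response(payload["prompt"], payload["model_id"])
        return jsonify({"response": text})
    except Exception as e:
        # Loading or generation failures are reported to the client here.
        return jsonify({"error": str(e)}), 500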