YALCINKAYA committed
Commit 9143358 · verified · 1 Parent(s): cbe8e48

Update app.py

Files changed (1)
  1. app.py +19 -7
app.py CHANGED
@@ -28,8 +28,19 @@ def get_model_and_tokenizer(model_id):
         tokenizer.pad_token = tokenizer.eos_token
 
         print(f"Loading model for model_id: {model_id} on {device}")
-        model = AutoModelForCausalLM.from_pretrained(model_id).to(device) # Move model to GPU
-        model.config.use_cache = False
+
+        bnb_config = BitsAndBytesConfig(
+            load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype="float16", bnb_4bit_use_double_quant=True
+        )
+
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id, quantization_config=bnb_config, device_map="auto"
+        )
+
+        model.config.use_cache=False
+        model.config.pretraining_tp=1
+
+
     except Exception as e:
         print(f"Error loading model: {e}")
         raise e # Raise the error to be caught in the POST request
@@ -40,9 +51,8 @@ def generate_response(user_input, model_id):
     # Ensure model and tokenizer are loaded
     get_model_and_tokenizer(model_id)
 
-    prompt = user_input
-    inputs = tokenizer([prompt], return_tensors="pt").to(device) # Move inputs to GPU
-
+    prompt = user_input
+
     generation_config = GenerationConfig(
         penalty_alpha=0.6,
         do_sample=True,
@@ -55,8 +65,10 @@ def generate_response(user_input, model_id):
         stop_sequences=["User:", "Assistant:", "\n"],
     )
 
-    outputs = model.generate(**inputs, generation_config=generation_config)
-    response = tokenizer.decode(outputs[:, inputs['input_ids'].shape[-1]:][0], skip_special_tokens=True)
+
+    inputs = tokenizer(prompt, return_tensors="pt").to('cuda')
+    response = (tokenizer.decode(outputs[0], skip_special_tokens=True))
+
     cleaned_response = response.replace("User:", "").replace("Assistant:", "").strip()
     return cleaned_response.strip().split("\n")[0] # Keep only the first line of response
 
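
For reference, a minimal, self-contained sketch of the 4-bit NF4 loading path introduced in the first hunk. The model id below is a placeholder (not from the commit), torch.float16 is used in place of the string "float16", and the bitsandbytes and accelerate packages are assumed to be installed:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "your-org/your-model"  # placeholder, not from the commit

# Quantize weights to 4-bit NormalFloat, run compute in fp16, and double-quantize
# the quantization constants, mirroring the config added in this commit.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,  # equivalent to the string "float16"
    bnb_4bit_use_double_quant=True,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

# device_map="auto" lets accelerate place the quantized layers on the available
# device(s), replacing the earlier explicit .to(device) call.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
)
model.config.use_cache = False
model.config.pretraining_tp = 1

Disabling the KV cache and setting pretraining_tp=1, as the commit does, is a common setup when a model is being prepared for fine-tuning rather than pure inference.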
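
Note that the last hunk tokenizes the prompt and decodes outputs, but no model.generate(...) call appears among the updated lines, so outputs has to come from elsewhere (or the removed call has to be restored). Below is a minimal sketch of how the tokenize, generate, and decode steps fit together, assuming the model and tokenizer loaded above; max_new_tokens and pad_token_id are illustrative additions not taken from the commit, and the prompt-skipping slice follows the pre-change decode:

from transformers import GenerationConfig

# Sampling setup mirroring the visible parts of the commit's GenerationConfig;
# max_new_tokens and pad_token_id are assumptions added for a runnable example.
generation_config = GenerationConfig(
    penalty_alpha=0.6,
    do_sample=True,
    max_new_tokens=64,
    pad_token_id=tokenizer.eos_token_id,
)

prompt = "User: Hello, how are you?\nAssistant:"  # placeholder user input

# Tokenize onto the model's device instead of hard-coding 'cuda'.
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# The generate() call is not part of the updated hunk; it is shown here only so
# that `outputs` exists for the decode step.
outputs = model.generate(**inputs, generation_config=generation_config)

# Decode only the newly generated tokens, skipping the echoed prompt.
response = tokenizer.decode(
    outputs[0, inputs["input_ids"].shape[-1]:], skip_special_tokens=True
)

cleaned_response = response.replace("User:", "").replace("Assistant:", "").strip()
print(cleaned_response.split("\n")[0])  # keep only the first line of the response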