pradeep6kumar2024 committed
Commit 1710631 · 1 Parent(s): 1a8f82f

updated app.py

Files changed (1)
app.py  +72 -41
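Note on context: the hunks below begin at line 6 of app.py (`import time`), so the file's import block is not part of the diff. Judging from the names used in the changed code (gr, torch, AutoTokenizer, AutoModelForCausalLM, PeftModel, time), the top of the file presumably reads roughly as follows; this is an inferred sketch, not the committed lines:

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import time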
app.py CHANGED
@@ -6,7 +6,7 @@ import time
 
 # Configuration
 BASE_MODEL = "microsoft/phi-2"
-ADAPTER_MODEL = "pradeep6kumar2024/phi2-qlora-assistant"  # Your actual model ID
+ADAPTER_MODEL = "pradeep6kumar2024/phi2-qlora-assistant"
 
 class ModelWrapper:
     def __init__(self):
@@ -16,48 +16,76 @@ class ModelWrapper:
 
     def load_model(self):
         if not self.loaded:
-            print("Loading model and tokenizer...")
-            self.tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
-            base_model = AutoModelForCausalLM.from_pretrained(
-                BASE_MODEL,
-                torch_dtype=torch.float16,
-                device_map="auto",
-                trust_remote_code=True
-            )
-
-            print("Loading LoRA adapter...")
-            self.model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL)
-            self.loaded = True
-            print("Model loading complete!")
+            try:
+                print("Loading tokenizer...")
+                self.tokenizer = AutoTokenizer.from_pretrained(
+                    BASE_MODEL,
+                    trust_remote_code=True,
+                    padding_side="left"
+                )
+                self.tokenizer.pad_token = self.tokenizer.eos_token
+
+                print("Loading base model...")
+                base_model = AutoModelForCausalLM.from_pretrained(
+                    BASE_MODEL,
+                    torch_dtype=torch.float16,
+                    device_map="auto",
+                    trust_remote_code=True,
+                    use_flash_attention_2=False  # Disable flash attention if causing issues
+                )
+
+                print("Loading LoRA adapter...")
+                self.model = PeftModel.from_pretrained(
+                    base_model,
+                    ADAPTER_MODEL,
+                    torch_dtype=torch.float16,
+                    device_map="auto"
+                )
+                self.model.eval()
+                print("Model loading complete!")
+                self.loaded = True
+            except Exception as e:
+                print(f"Error during model loading: {str(e)}")
+                raise
 
-    def generate_response(self, prompt, max_length=512, temperature=0.7, top_p=0.9, stream=False):
+    def generate_response(self, prompt, max_length=512, temperature=0.7, top_p=0.9):
         if not self.loaded:
             self.load_model()
 
-        # Tokenize input
-        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
-        inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
-
-        # Generate
-        start_time = time.time()
-        with torch.no_grad():
-            outputs = self.model.generate(
-                **inputs,
-                max_length=max_length,
-                temperature=temperature,
-                top_p=top_p,
-                do_sample=True,
-                pad_token_id=self.tokenizer.pad_token_id,
-                eos_token_id=self.tokenizer.eos_token_id
-            )
-
-        # Decode response
-        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-        if response.startswith(prompt):
-            response = response[len(prompt):].strip()
-
-        generation_time = time.time() - start_time
-        return response, generation_time
+        try:
+            # Tokenize input
+            inputs = self.tokenizer(
+                prompt,
+                return_tensors="pt",
+                truncation=True,
+                max_length=512,
+                padding=True
+            ).to(self.model.device)
+
+            # Generate
+            start_time = time.time()
+            with torch.no_grad():
+                outputs = self.model.generate(
+                    **inputs,
+                    max_length=max_length,
+                    temperature=temperature,
+                    top_p=top_p,
+                    do_sample=True,
+                    pad_token_id=self.tokenizer.pad_token_id,
+                    eos_token_id=self.tokenizer.eos_token_id,
+                    repetition_penalty=1.1
+                )
+
+            # Decode response
+            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+            if response.startswith(prompt):
+                response = response[len(prompt):].strip()
+
+            generation_time = time.time() - start_time
+            return response, generation_time
+        except Exception as e:
+            print(f"Error during generation: {str(e)}")
+            raise
 
 # Initialize model wrapper
 model_wrapper = ModelWrapper()
@@ -65,6 +93,9 @@ model_wrapper = ModelWrapper()
 def generate_text(prompt, max_length=512, temperature=0.7, top_p=0.9):
     """Gradio interface function"""
     try:
+        if not prompt.strip():
+            return "Please enter a prompt."
+
         response, gen_time = model_wrapper.generate_response(
             prompt,
             max_length=max_length,
@@ -73,7 +104,8 @@ def generate_text(prompt, max_length=512, temperature=0.7, top_p=0.9):
         )
         return f"Generated in {gen_time:.2f} seconds:\n\n{response}"
     except Exception as e:
-        return f"Error generating response: {str(e)}"
+        print(f"Error in generate_text: {str(e)}")
+        return f"Error generating response: {str(e)}\nPlease try again with a different prompt or parameters."
 
 # Create the Gradio interface
 demo = gr.Interface(
@@ -159,6 +191,5 @@ demo = gr.Interface(
     cache_examples=False
 )
 
-# Launch with sharing enabled
 if __name__ == "__main__":
     demo.launch()
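The revised load path lazily loads microsoft/phi-2 in float16, sets the pad token to the EOS token, attaches the pradeep6kumar2024/phi2-qlora-assistant LoRA adapter via PeftModel, and wraps both loading and generation in try/except. The sketch below condenses that same load-and-generate path into a standalone script for checking the adapter outside the Gradio app; the prompt string is illustrative, and the script mirrors rather than reproduces the committed code.

import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

BASE_MODEL = "microsoft/phi-2"
ADAPTER_MODEL = "pradeep6kumar2024/phi2-qlora-assistant"

# Tokenizer: left padding and pad_token = eos_token, as in load_model() above
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token

# Base model in float16, then the LoRA adapter layered on top
base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL, torch_dtype=torch.float16, device_map="auto", trust_remote_code=True
)
model = PeftModel.from_pretrained(base, ADAPTER_MODEL, torch_dtype=torch.float16, device_map="auto")
model.eval()

prompt = "Explain what a LoRA adapter does."  # illustrative prompt, not from the app
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(model.device)

start = time.time()
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_length=512,            # counts prompt tokens plus generated tokens
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        repetition_penalty=1.1,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"Generated in {time.time() - start:.2f}s:\n{text[len(prompt):].strip()}")

One caveat when tuning the sliders: `max_length` in `generate()` includes the prompt tokens, so a long prompt shrinks the response budget; `max_new_tokens` is the usual way to fix the response length independently of the prompt.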