pradeep6kumar2024 commited on
Commit
4c33bc8
·
1 Parent(s): 1408e00

updated app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -8
app.py CHANGED
@@ -53,9 +53,21 @@ class ModelWrapper:
53
  self.load_model()
54
 
55
  try:
 
 
 
 
 
 
 
 
 
 
 
 
56
  # Tokenize input
57
  inputs = self.tokenizer(
58
- prompt,
59
  return_tensors="pt",
60
  truncation=True,
61
  max_length=512,
@@ -68,23 +80,25 @@ class ModelWrapper:
68
  outputs = self.model.generate(
69
  **inputs,
70
  max_length=max_length,
 
71
  temperature=temperature,
72
  top_p=top_p,
73
  do_sample=True,
74
  pad_token_id=self.tokenizer.pad_token_id,
75
  eos_token_id=self.tokenizer.eos_token_id,
76
- repetition_penalty=1.2, # Increased to reduce repetition
77
- no_repeat_ngram_size=3, # Prevent repeating of 3-grams
78
- early_stopping=True, # Stop when EOS token is generated
79
- stopping_criteria=None # Will use default stopping criteria
 
80
  )
81
 
82
  # Decode response
83
  response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
84
 
85
  # Clean up the response
86
- if response.startswith(prompt):
87
- response = response[len(prompt):].strip()
88
 
89
  # Remove common closure patterns
90
  closures = [
@@ -96,13 +110,42 @@ class ModelWrapper:
96
  "Assistant:",
97
  "Human:",
98
  "[Your Name]",
99
- "[Student]"
 
 
 
100
  ]
101
 
102
  for closure in closures:
103
  if closure.lower() in response.lower():
104
  response = response[:response.lower().find(closure.lower())].strip()
105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  generation_time = time.time() - start_time
107
  return response, generation_time
108
  except Exception as e:
 
53
  self.load_model()
54
 
55
  try:
56
+ # Enhance prompt for better completion
57
+ if "function" in prompt.lower() and "python" in prompt.lower():
58
+ enhanced_prompt = f"""Write a Python function with the following requirements:
59
+ {prompt}
60
+ Include:
61
+ - Function implementation with comments
62
+ - Example usage
63
+ - Output demonstration
64
+ """
65
+ else:
66
+ enhanced_prompt = prompt
67
+
68
  # Tokenize input
69
  inputs = self.tokenizer(
70
+ enhanced_prompt,
71
  return_tensors="pt",
72
  truncation=True,
73
  max_length=512,
 
80
  outputs = self.model.generate(
81
  **inputs,
82
  max_length=max_length,
83
+ min_length=50, # Reduced minimum length
84
  temperature=temperature,
85
  top_p=top_p,
86
  do_sample=True,
87
  pad_token_id=self.tokenizer.pad_token_id,
88
  eos_token_id=self.tokenizer.eos_token_id,
89
+ repetition_penalty=1.1, # Reduced to allow more natural responses
90
+ no_repeat_ngram_size=3,
91
+ early_stopping=True,
92
+ num_beams=3, # Increased beam search
93
+ length_penalty=0.8 # Adjusted to prevent too long responses
94
  )
95
 
96
  # Decode response
97
  response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
98
 
99
  # Clean up the response
100
+ if response.startswith(enhanced_prompt):
101
+ response = response[len(enhanced_prompt):].strip()
102
 
103
  # Remove common closure patterns
104
  closures = [
 
110
  "Assistant:",
111
  "Human:",
112
  "[Your Name]",
113
+ "[Student]",
114
+ "Let me know if you need any clarification",
115
+ "I hope this helps",
116
+ "Feel free to ask"
117
  ]
118
 
119
  for closure in closures:
120
  if closure.lower() in response.lower():
121
  response = response[:response.lower().find(closure.lower())].strip()
122
 
123
+ # Ensure code examples are properly formatted
124
+ if "```python" not in response and "def " in response:
125
+ response = "```python\n" + response + "\n```"
126
+
127
+ # If response is empty or too short, try a fallback response
128
+ if len(response.strip()) < 10:
129
+ fallback_response = """```python
130
+ def add_numbers(a, b):
131
+ '''
132
+ Add two numbers and return the result
133
+ Args:
134
+ a: first number
135
+ b: second number
136
+ Returns:
137
+ sum of a and b
138
+ '''
139
+ return a + b
140
+
141
+ # Example usage
142
+ num1 = 5
143
+ num2 = 3
144
+ result = add_numbers(num1, num2)
145
+ print(f"The sum of {num1} and {num2} is: {result}") # Output: The sum of 5 and 3 is: 8
146
+ ```"""
147
+ response = fallback_response
148
+
149
  generation_time = time.time() - start_time
150
  return response, generation_time
151
  except Exception as e: