rish13 commited on
Commit
892e1e5
·
verified ·
1 Parent(s): 1694eaa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -3
app.py CHANGED
@@ -7,16 +7,22 @@ model = pipeline("text-generation", model="rish13/polymers")
7
  def generate_response(prompt):
8
  # Generate text from the model
9
  response = model(prompt, max_length=150, num_return_sequences=1)
 
 
10
  generated_text = response[0]['generated_text']
11
 
12
  # Find the position of the first end-of-sentence punctuation
13
  end_punctuation = ['.', '!', '?']
14
- end_position = min((generated_text.find(punct) for punct in end_punctuation if punct in generated_text), default=-1)
 
 
 
 
15
 
 
16
  if end_position != -1:
17
- # Include the punctuation in the response
18
  generated_text = generated_text[:end_position + 1]
19
-
20
  return generated_text
21
 
22
  # Define the Gradio interface
 
7
  def generate_response(prompt):
8
  # Generate text from the model
9
  response = model(prompt, max_length=150, num_return_sequences=1)
10
+
11
+ # Get the generated text from the response
12
  generated_text = response[0]['generated_text']
13
 
14
  # Find the position of the first end-of-sentence punctuation
15
  end_punctuation = ['.', '!', '?']
16
+ end_position = -1
17
+ for punct in end_punctuation:
18
+ pos = generated_text.find(punct)
19
+ if pos != -1 and (end_position == -1 or pos < end_position):
20
+ end_position = pos
21
 
22
+ # If punctuation is found, truncate the text at that point
23
  if end_position != -1:
 
24
  generated_text = generated_text[:end_position + 1]
25
+
26
  return generated_text
27
 
28
  # Define the Gradio interface