Bey007 commited on
Commit
882e739
·
verified ·
1 Parent(s): 00e8e64

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -7
app.py CHANGED
@@ -14,26 +14,31 @@ gpt2_model = AutoModelForCausalLM.from_pretrained("gpt2-medium")
14
 
15
  emotion_classifier = pipeline("text-classification", model="bhadresh-savani/distilbert-base-uncased-emotion", return_all_scores=True)
16
 
17
- # Function to generate a comforting story using GPT-2
18
  def generate_story(theme):
19
  # A detailed prompt for generating a comforting story about the selected theme
20
  story_prompt = f"Write a comforting, detailed, and heartwarming story about {theme}. The story should include a character who faces a tough challenge, finds hope, and ultimately overcomes the situation with a positive resolution."
21
 
22
- # Generate story using GPT-2
23
  input_ids = gpt2_tokenizer.encode(story_prompt, return_tensors='pt')
24
 
25
  story_ids = gpt2_model.generate(
26
  input_ids,
27
- max_length=500, # Generate longer stories
28
- temperature=0.8, # Balanced creativity
29
- top_p=0.9,
30
- repetition_penalty=1.2,
 
31
  num_return_sequences=1
32
  )
33
 
34
  # Decode the generated text
35
  story = gpt2_tokenizer.decode(story_ids[0], skip_special_tokens=True)
36
- return story
 
 
 
 
 
37
 
38
  def generate_response(user_input):
39
  # Limit user input length to prevent overflow issues
 
14
 
15
  emotion_classifier = pipeline("text-classification", model="bhadresh-savani/distilbert-base-uncased-emotion", return_all_scores=True)
16
 
 
17
  def generate_story(theme):
18
  # A detailed prompt for generating a comforting story about the selected theme
19
  story_prompt = f"Write a comforting, detailed, and heartwarming story about {theme}. The story should include a character who faces a tough challenge, finds hope, and ultimately overcomes the situation with a positive resolution."
20
 
21
+ # Generate story using GPT-2 with adjusted parameters
22
  input_ids = gpt2_tokenizer.encode(story_prompt, return_tensors='pt')
23
 
24
  story_ids = gpt2_model.generate(
25
  input_ids,
26
+ max_length=450, # Generate slightly shorter but focused stories
27
+ temperature=0.7, # Balanced creativity without too much randomness
28
+ top_p=0.9, # Encourage diversity in output
29
+ top_k=50, # Limit to more probable words
30
+ repetition_penalty=1.2, # Prevent repetitive patterns
31
  num_return_sequences=1
32
  )
33
 
34
  # Decode the generated text
35
  story = gpt2_tokenizer.decode(story_ids[0], skip_special_tokens=True)
36
+
37
+ # Clean up the generated story by removing the initial prompt
38
+ cleaned_response = story.replace(story_prompt, "").strip()
39
+
40
+ return cleaned_response
41
+
42
 
43
  def generate_response(user_input):
44
  # Limit user input length to prevent overflow issues