apratim24 committed on
Commit
9f6cc8d
·
verified ·
1 Parent(s): cc1ca46

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -1,13 +1,13 @@
1
  import gradio as gr
 
2
  from transformers import pipeline
3
  from transformers import AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel
4
 
5
  import os
6
- api_key = os.getenv("OPENAI_API_KEY")
7
- client = OpenAI(api_key=api_key)
8
 
9
  # Load text generation model
10
- text_generation_model = pipeline("text-generation", model="openai-community/gpt2-large")
11
  # text_generation_model = pipeline("text-generation", model="distilbert/distilgpt2")
12
 
13
  # Load image captioning model
@@ -33,9 +33,10 @@ def generate_story(image, theme, genre):
33
  caption_text = tokenizer.batch_decode(caption_ids, skip_special_tokens=True)[0]
34
 
35
  # Generate story based on the caption
36
- story_prompt = f"Write an interesting {theme} story in the {genre} genre. The story should be about {caption_text}."
37
 
38
- story = text_generation_model(story_prompt, max_length=150)[0]["generated_text"]
 
39
 
40
  return story
41
  except Exception as e:
 
1
  import gradio as gr
2
+ from langchain_openai import OpenAI
3
  from transformers import pipeline
4
  from transformers import AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel
5
 
6
  import os
7
+ openai_api_key = os.getenv("OPENAI_API_KEY")
 
8
 
9
  # Load text generation model
10
+ # text_generation_model = pipeline("text-generation", model="openai-community/gpt2-large")
11
  # text_generation_model = pipeline("text-generation", model="distilbert/distilgpt2")
12
 
13
  # Load image captioning model
 
33
  caption_text = tokenizer.batch_decode(caption_ids, skip_special_tokens=True)[0]
34
 
35
  # Generate story based on the caption
36
+ story_prompt = f"Write an interesting {theme} story in the {genre} genre. The story should be within 100 words about {caption_text}."
37
 
38
+ llm = OpenAI(model_name="gpt-3.5-turbo-instruct", openai_api_key=openai_api_key)
39
+ story = llm.invoke(story_prompt)
40
 
41
  return story
42
  except Exception as e: