tournas committed on
Commit
17a4149
·
verified ·
1 Parent(s): 0ec87e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -6
app.py CHANGED
@@ -13,7 +13,7 @@ from nltk.tokenize import sent_tokenize
13
  from IPython.display import Audio
14
  import spaces
15
 
16
- device = 'cuda'
17
 
18
  api_key = os.getenv("OPENAI_API_KEY")
19
  if not api_key:
@@ -27,7 +27,7 @@ yolo_model = YOLO("yolov8s.pt")
27
  stable_diffusion = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
28
  stable_diffusion.to(device)
29
  nltk.download("punkt")
30
- summarizer = pipeline("summarization", model="facebook/bart-large-cnn", device=device)
31
 
32
  @spaces.GPU
33
  def detect_objects(image_path):
@@ -47,7 +47,7 @@ def generate_story(detected_objects):
47
  messages=[{"role": "user", "content": story_prompt}],
48
  max_tokens=200
49
  )
50
- return response.choices[0].message.content
51
 
52
  def summarize_story(story):
53
  summary = summarizer(story, max_length=100, do_sample=False)[0]['summary_text']
@@ -59,7 +59,7 @@ def generate_images(story):
59
  prompts = [f"Highly detailed, cinematic scene: {scene}, digital art, 4K, realistic lighting" for scene in scenes]
60
  images = []
61
  for prompt in prompts:
62
- image = stable_diffusion(prompt).images[0]
63
  images.append(image)
64
  return images
65
 
@@ -70,10 +70,12 @@ def text_to_speech(story):
70
  return audio_file_path
71
 
72
  def full_pipeline(image):
73
- detected_objects = detect_objects(image)
 
 
74
  story = generate_story(detected_objects)
75
  scenes = summarize_story(story)
76
- images = generate_images(scenes)
77
  audio = text_to_speech(story)
78
 
79
  return story, scenes, images, audio
 
13
  from IPython.display import Audio
14
  import spaces
15
 
16
+ device = "cuda" if torch.cuda.is_available() else "cpu"
17
 
18
  api_key = os.getenv("OPENAI_API_KEY")
19
  if not api_key:
 
27
  stable_diffusion = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
28
  stable_diffusion.to(device)
29
  nltk.download("punkt")
30
+ summarizer = pipeline("summarization", model="facebook/bart-large-cnn", device= 'cuda' if torch.cuda.is_available() else 'cpu')
31
 
32
  @spaces.GPU
33
  def detect_objects(image_path):
 
47
  messages=[{"role": "user", "content": story_prompt}],
48
  max_tokens=200
49
  )
50
+ return response.choices[0].message.content.strip()
51
 
52
  def summarize_story(story):
53
  summary = summarizer(story, max_length=100, do_sample=False)[0]['summary_text']
 
59
  prompts = [f"Highly detailed, cinematic scene: {scene}, digital art, 4K, realistic lighting" for scene in scenes]
60
  images = []
61
  for prompt in prompts:
62
+ image = stable_diffusion(prompt=prompt).images[0]
63
  images.append(image)
64
  return images
65
 
 
70
  return audio_file_path
71
 
72
  def full_pipeline(image):
73
+ image_path = "input.jpg"
74
+ image.save(image_path)
75
+ detected_objects = detect_objects(image_path)
76
  story = generate_story(detected_objects)
77
  scenes = summarize_story(story)
78
+ images = generate_images(story)
79
  audio = text_to_speech(story)
80
 
81
  return story, scenes, images, audio