tournas committed on
Commit
dd0e420
·
verified ·
1 Parent(s): b03e0bc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -12
app.py CHANGED
@@ -10,7 +10,6 @@ from gtts import gTTS
10
  from PIL import Image
11
  import numpy as np
12
  from nltk.tokenize import sent_tokenize
13
- from IPython.display import Audio
14
  import spaces
15
 
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -22,22 +21,21 @@ if not api_key:
22
  client = OpenAI(api_key=api_key)
23
 
24
  yolo_model = YOLO("yolov8s.pt")
 
25
  stable_diffusion = StableDiffusionPipeline.from_pretrained(
26
  "runwayml/stable-diffusion-v1-5",
27
- torch_dtype=torch.float16,
28
- safety_checker=None
29
  ).to(device)
30
- stable_diffusion.to(device)
31
  nltk.download("punkt")
32
 
33
  summarizer = pipeline(
34
  "summarization",
35
  model="sshleifer/distilbart-cnn-6-6"
36
  )
37
-
38
  @spaces.GPU
39
  def detect_objects(image):
40
- image_array = np.array(image)
41
  results = yolo_model(image_array)
42
  detected_objects = []
43
  for r in results:
@@ -49,12 +47,12 @@ def detect_objects(image):
49
 
50
  def generate_story(detected_objects):
51
  story_prompt = f"Write a short story based on the following objects: {', '.join(detected_objects)}"
52
- response = client.completions.create(
53
  model="gpt-4o-mini",
54
  messages=[{"role": "user", "content": story_prompt}],
55
  max_tokens=200
56
  )
57
- return response.choices[0].text.strip() # Διορθώθηκε
58
 
59
  def summarize_story(story):
60
  summary = summarizer(story, max_length=100, do_sample=False)[0]['summary_text']
@@ -66,7 +64,7 @@ def generate_images(story):
66
  prompts = [f"Highly detailed, cinematic scene: {scene}, digital art, 4K, realistic lighting" for scene in scenes]
67
  images = []
68
  for prompt in prompts:
69
- image = stable_diffusion(prompt=prompt).images[0] # Διορθώθηκε
70
  images.append(image)
71
  return images
72
 
@@ -77,9 +75,7 @@ def text_to_speech(story):
77
  return audio_file_path
78
 
79
  def full_pipeline(image):
80
- image_path = "input.jpg"
81
- image.save(image_path) # Διορθώθηκε
82
- detected_objects = detect_objects(image_path)
83
  story = generate_story(detected_objects)
84
  scenes = summarize_story(story)
85
  images = generate_images(story)
@@ -103,3 +99,4 @@ demo = gr.Interface(
103
 
104
  if __name__ == "__main__":
105
  demo.launch()
 
 
10
  from PIL import Image
11
  import numpy as np
12
  from nltk.tokenize import sent_tokenize
 
13
  import spaces
14
 
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
21
  client = OpenAI(api_key=api_key)
22
 
23
  yolo_model = YOLO("yolov8s.pt")
24
+
25
  stable_diffusion = StableDiffusionPipeline.from_pretrained(
26
  "runwayml/stable-diffusion-v1-5",
27
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
 
28
  ).to(device)
29
+
30
  nltk.download("punkt")
31
 
32
  summarizer = pipeline(
33
  "summarization",
34
  model="sshleifer/distilbart-cnn-6-6"
35
  )
 
36
  @spaces.GPU
37
  def detect_objects(image):
38
+ image_array = np.array(image) # Μετατροπή PIL → NumPy
39
  results = yolo_model(image_array)
40
  detected_objects = []
41
  for r in results:
 
47
 
48
  def generate_story(detected_objects):
49
  story_prompt = f"Write a short story based on the following objects: {', '.join(detected_objects)}"
50
+ response = client.chat.completions.create(
51
  model="gpt-4o-mini",
52
  messages=[{"role": "user", "content": story_prompt}],
53
  max_tokens=200
54
  )
55
+ return response.choices[0].message.content.strip()
56
 
57
  def summarize_story(story):
58
  summary = summarizer(story, max_length=100, do_sample=False)[0]['summary_text']
 
64
  prompts = [f"Highly detailed, cinematic scene: {scene}, digital art, 4K, realistic lighting" for scene in scenes]
65
  images = []
66
  for prompt in prompts:
67
+ image = stable_diffusion(prompt).images[0] # Διόρθωση
68
  images.append(image)
69
  return images
70
 
 
75
  return audio_file_path
76
 
77
  def full_pipeline(image):
78
+ detected_objects = detect_objects(image)
 
 
79
  story = generate_story(detected_objects)
80
  scenes = summarize_story(story)
81
  images = generate_images(story)
 
99
 
100
  if __name__ == "__main__":
101
  demo.launch()
102
+