tournas commited on
Commit
0b47b5a
·
verified ·
1 Parent(s): 483fc16

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -5
app.py CHANGED
@@ -22,19 +22,28 @@ if not api_key:
22
  client = OpenAI(api_key=api_key)
23
 
24
  yolo_model = YOLO("yolov8s.pt")
25
- stable_diffusion = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
 
 
 
 
26
  stable_diffusion.to(device)
27
  nltk.download("punkt")
28
 
29
- summarizer = pipeline("summarization", model="facebook/bart-large-cnn", device=0 if torch.cuda.is_available() else -1)
 
 
 
 
30
 
31
  @spaces.GPU
32
- def detect_objects(image_path):
33
- results = yolo_model(image_path)
 
34
  detected_objects = []
35
  for r in results:
36
  for box in r.boxes:
37
- class_id = int(box.cls.item()) # Διορθώθηκε
38
  label = yolo_model.names[class_id]
39
  detected_objects.append(label)
40
  return detected_objects
 
22
  client = OpenAI(api_key=api_key)
23
 
24
  yolo_model = YOLO("yolov8s.pt")
25
+ stable_diffusion = StableDiffusionPipeline.from_pretrained(
26
+ "runwayml/stable-diffusion-v1-5",
27
+ torch_dtype=torch.float16,
28
+ safety_checker=None
29
+ ).to(device)
30
  stable_diffusion.to(device)
31
  nltk.download("punkt")
32
 
33
+ summarizer = pipeline(
34
+ "summarization",
35
+ model="facebook/bart-large-cnn",
36
+ device_map="auto"
37
+ )
38
 
39
  @spaces.GPU
40
+ def detect_objects(image):
41
+ image_array = np.array(image)
42
+ results = yolo_model(image_array)
43
  detected_objects = []
44
  for r in results:
45
  for box in r.boxes:
46
+ class_id = int(box.cls.item())
47
  label = yolo_model.names[class_id]
48
  detected_objects.append(label)
49
  return detected_objects