Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -22,19 +22,28 @@ if not api_key:
|
|
22 |
client = OpenAI(api_key=api_key)
|
23 |
|
24 |
yolo_model = YOLO("yolov8s.pt")
|
25 |
-
stable_diffusion = StableDiffusionPipeline.from_pretrained(
|
|
|
|
|
|
|
|
|
26 |
stable_diffusion.to(device)
|
27 |
nltk.download("punkt")
|
28 |
|
29 |
-
summarizer = pipeline(
|
|
|
|
|
|
|
|
|
30 |
|
31 |
@spaces.GPU
|
32 |
-
def detect_objects(
|
33 |
-
|
|
|
34 |
detected_objects = []
|
35 |
for r in results:
|
36 |
for box in r.boxes:
|
37 |
-
class_id = int(box.cls.item())
|
38 |
label = yolo_model.names[class_id]
|
39 |
detected_objects.append(label)
|
40 |
return detected_objects
|
|
|
22 |
client = OpenAI(api_key=api_key)
|
23 |
|
24 |
yolo_model = YOLO("yolov8s.pt")
|
25 |
+
stable_diffusion = StableDiffusionPipeline.from_pretrained(
|
26 |
+
"runwayml/stable-diffusion-v1-5",
|
27 |
+
torch_dtype=torch.float16,
|
28 |
+
safety_checker=None
|
29 |
+
).to(device)
|
30 |
stable_diffusion.to(device)
|
31 |
nltk.download("punkt")
|
32 |
|
33 |
+
summarizer = pipeline(
|
34 |
+
"summarization",
|
35 |
+
model="facebook/bart-large-cnn",
|
36 |
+
device_map="auto"
|
37 |
+
)
|
38 |
|
39 |
@spaces.GPU
|
40 |
+
def detect_objects(image):
|
41 |
+
image_array = np.array(image)
|
42 |
+
results = yolo_model(image_array)
|
43 |
detected_objects = []
|
44 |
for r in results:
|
45 |
for box in r.boxes:
|
46 |
+
class_id = int(box.cls.item())
|
47 |
label = yolo_model.names[class_id]
|
48 |
detected_objects.append(label)
|
49 |
return detected_objects
|