Alex commited on
Commit
077a679
·
1 Parent(s): f6210c2

added new endpoint

Browse files
Files changed (2) hide show
  1. README.md +5 -1
  2. app.py +57 -0
README.md CHANGED
@@ -37,4 +37,8 @@ ir
37
 
38
  curl -X POST "http://localhost:7860/segment" \
39
  -H "Content-Type: application/json" \
40
- -d "{\"image_base64\": \"$(base64 woman_with_bag.jpeg)\"}"
 
 
 
 
 
37
 
38
  curl -X POST "http://localhost:7860/segment" \
39
  -H "Content-Type: application/json" \
40
+ -d "{\"image_base64\": \"$(base64 woman_with_bag.jpeg)\"}"
41
+
42
+
43
+ # Output
44
+ {"mask":"data:image/png;base64...","annotations":{"mask":[[]]"label":"fashion"}}
app.py CHANGED
@@ -1,15 +1,22 @@
1
  from fastapi import FastAPI, File, UploadFile, HTTPException
2
  from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
3
  import torch
 
4
  from PIL import Image
5
  import numpy as np
6
  import io
7
  import base64
8
  import logging
 
 
9
 
10
  # Inizializza l'app FastAPI
11
  app = FastAPI()
12
 
 
 
 
 
13
  # Configura il logging
14
  logging.basicConfig(level=logging.INFO)
15
  logger = logging.getLogger(__name__)
@@ -74,6 +81,56 @@ async def segment_endpoint(file: UploadFile = File(...)):
74
  logger.error(f"Errore nell'endpoint: {str(e)}")
75
  raise HTTPException(status_code=500, detail=f"Errore nell'elaborazione: {str(e)}")
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  # Per compatibilità con Hugging Face Spaces
78
  if __name__ == "__main__":
79
  import uvicorn
 
1
  from fastapi import FastAPI, File, UploadFile, HTTPException
2
  from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
3
  import torch
4
+ from pydantic import BaseModel
5
  from PIL import Image
6
  import numpy as np
7
  import io
8
  import base64
9
  import logging
10
+ import requests
11
+ import torch.nn as nn
12
 
13
  # Inizializza l'app FastAPI
14
  app = FastAPI()
15
 
16
+ # Add this class for the request body
17
+ class ImageURL(BaseModel):
18
+ url: str
19
+
20
  # Configura il logging
21
  logging.basicConfig(level=logging.INFO)
22
  logger = logging.getLogger(__name__)
 
81
  logger.error(f"Errore nell'endpoint: {str(e)}")
82
  raise HTTPException(status_code=500, detail=f"Errore nell'elaborazione: {str(e)}")
83
 
84
+
85
+
86
+ # Add new endpoint
87
+ @app.post("/segment-url")
88
+ async def segment_url_endpoint(image_data: ImageURL):
89
+ try:
90
+ logger.info("Downloading image from URL...")
91
+ response = requests.get(image_data.url, stream=True)
92
+ if response.status_code != 200:
93
+ raise HTTPException(status_code=400, detail="Could not download image from URL")
94
+
95
+ # Open image from URL
96
+ image = Image.open(response.raw).convert("RGB")
97
+
98
+ # Process image with SegFormer
99
+ logger.info("Processing image...")
100
+ inputs = processor(images=image, return_tensors="pt")
101
+ outputs = model(**inputs)
102
+ logits = outputs.logits.cpu()
103
+
104
+ # Upsample logits to match original image size
105
+ upsampled_logits = nn.functional.interpolate(
106
+ logits,
107
+ size=image.size[::-1],
108
+ mode="bilinear",
109
+ align_corners=False,
110
+ )
111
+
112
+ # Get prediction
113
+ pred_seg = upsampled_logits.argmax(dim=1)[0]
114
+
115
+ # Convert to image
116
+ mask_img = Image.fromarray((pred_seg.numpy() * 255).astype(np.uint8))
117
+
118
+ # Convert to base64
119
+ buffered = io.BytesIO()
120
+ mask_img.save(buffered, format="PNG")
121
+ mask_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
122
+
123
+ return {
124
+ "mask": f"data:image/png;base64,{mask_base64}",
125
+ "size": image.size,
126
+ "labels" : pred_seg
127
+ }
128
+
129
+ except Exception as e:
130
+ logger.error(f"Error processing URL: {str(e)}")
131
+ raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
132
+
133
+
134
  # Per compatibilità con Hugging Face Spaces
135
  if __name__ == "__main__":
136
  import uvicorn