Saim8250 commited on
Commit
db3dea2
·
verified ·
1 Parent(s): 757829a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -17
app.py CHANGED
@@ -1,25 +1,45 @@
1
  import gradio as gr
2
- from transformers import pipeline
 
 
3
 
4
- pipeline = pipeline(task="image-segmentation", model="Saim8250/litemedsam")
 
 
 
5
 
6
- def predict(image_path):
7
- # Perform image segmentation
8
- predictions = pipe(image_path)
 
 
 
 
 
 
 
9
 
10
- # Access the segmented mask or any other relevant information in predictions
11
- segmented_mask = predictions["segmentation_mask"]
 
 
12
 
13
- # Convert the segmentation mask to an RGB image
14
- segmented_image = colorize_mask(segmented_mask)
 
15
 
 
 
 
 
16
  return segmented_image
17
 
18
-
19
- gr.Interface(
20
- predict,
21
- inputs=gr.inputs.Image(label="Upload image", type="filepath"),
22
- outputs=gr.outputs.Label(num_top_classes=2),
23
- title="segmentation",
24
- allow_flagging="manual"
25
- ).launch()
 
 
1
  import gradio as gr
2
+ import torch
3
+ from torchvision import transforms
4
+ from PIL import Image
5
 
6
+ # Load the segmentation model (replace 'path/to/lightmed_model' with the actual path)
7
+ model_path = 'path/to/lightmed_model'
8
+ segmentation_model = torch.load(model_path, map_location=torch.device('cpu'))
9
+ segmentation_model.eval()
10
 
11
+ # Define the preprocessing function for the input image
12
+ def preprocess(image):
13
+ # Resize the image to match the model's expected input size
14
+ transform = transforms.Compose([
15
+ transforms.Resize((256, 256)),
16
+ transforms.ToTensor(),
17
+ ])
18
+ img = Image.fromarray(image)
19
+ img = transform(img).unsqueeze(0)
20
+ return img
21
 
22
+ # Define the segmentation function
23
+ def segment_image(input_image):
24
+ # Preprocess the input image
25
+ input_tensor = preprocess(input_image)
26
 
27
+ # Perform segmentation using the model
28
+ with torch.no_grad():
29
+ output = segmentation_model(input_tensor)
30
 
31
+ # Convert the output tensor to a segmented image
32
+ segmented_image = torch.argmax(output, dim=1).squeeze().numpy()
33
+
34
+ # Return the segmented image
35
  return segmented_image
36
 
37
+ # Define the Gradio interface
38
+ iface = gr.Interface(
39
+ fn=segment_image,
40
+ inputs=gr.Image(type="pil", preprocess=preprocess),
41
+ outputs=gr.Image(type="numpy")
42
+ )
43
+
44
+ # Launch the Gradio app
45
+ iface.launch()