Add custom handler.py for Inference Endpoints

#1
.gitattributes CHANGED
@@ -33,4 +33,3 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
- *.png filter=lfs diff=lfs merge=lfs -text
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
Photos/example1.png DELETED

Git LFS Details

  • SHA256: bfb055d5b1340ee52078020b244a2ee0863589cc6752e1ee85bb7b8c77b49f1f
  • Pointer size: 131 Bytes
  • Size of remote file: 464 kB
Photos/example2.png DELETED

Git LFS Details

  • SHA256: b6b16825dbadef586638417e07bed3a4b99b2a5f14a46310c6f84337eb3a1956
  • Pointer size: 131 Bytes
  • Size of remote file: 572 kB
Photos/example3.png DELETED

Git LFS Details

  • SHA256: d4c2bf7b7a343c47861036f58d0e8525751043d40e93314e1c56c343b12aae17
  • Pointer size: 131 Bytes
  • Size of remote file: 521 kB
Photos/example4.png DELETED

Git LFS Details

  • SHA256: 371eaf33581723f9b92a1a8520d7cf583f35d2165cdfbb1b78338643029525d4
  • Pointer size: 131 Bytes
  • Size of remote file: 433 kB
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
  title: Sat2map
3
- emoji: 🗺️
4
  colorFrom: yellow
5
  colorTo: purple
6
  sdk: gradio
 
1
  ---
2
  title: Sat2map
3
+ emoji: 💬
4
  colorFrom: yellow
5
  colorTo: purple
6
  sdk: gradio
app.py CHANGED
@@ -1,118 +1,61 @@
1
  import gradio as gr
2
  import torch
3
  from huggingface_hub import hf_hub_download
 
 
4
  import sys
5
  import os
6
  from PIL import Image
7
  import torchvision.transforms as transforms
8
 
9
-
10
- photos_folder = "Photos"
11
-
12
- # Download model and config
13
  repo_id = "Kiwinicki/sat2map-generator"
14
  generator_path = hf_hub_download(repo_id=repo_id, filename="generator.pth")
 
15
  model_path = hf_hub_download(repo_id=repo_id, filename="model.py")
16
 
17
- # Add path to model
18
  sys.path.append(os.path.dirname(model_path))
19
- from model import Generator, GeneratorConfig
 
 
 
 
 
20
 
21
- # Initialize model
22
- cfg = GeneratorConfig()
23
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
  generator = Generator(cfg).to(device)
25
  generator.load_state_dict(torch.load(generator_path, map_location=device))
26
  generator.eval()
27
 
28
- # Transformations
29
  transform = transforms.Compose([
30
  transforms.Resize((256, 256)),
31
  transforms.ToTensor(),
32
  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
33
  ])
34
 
35
-
36
  def process_image(image):
37
- if image is None:
38
- return None
39
-
40
- # Convert to tensor
41
  image_tensor = transform(image).unsqueeze(0).to(device)
42
 
43
- # Inference
44
  with torch.no_grad():
45
  output_tensor = generator(image_tensor)
46
 
47
- # Prepare output
48
  output_image = output_tensor.squeeze(0).cpu()
49
- output_image = output_image * 0.5 + 0.5 # Denormalization
50
  output_image = transforms.ToPILImage()(output_image)
 
51
  return output_image
52
 
53
- def load_images_from_folder(folder):
54
- images = []
55
- if not os.path.exists(folder):
56
- os.makedirs(folder)
57
- return images
58
-
59
- for filename in os.listdir(folder):
60
- if filename.lower().endswith(('.png')):
61
- img_path = os.path.join(folder, filename)
62
- try:
63
- img = Image.open(img_path)
64
- images.append((img, filename))
65
- except Exception as e:
66
- print(f"Error loading {filename}: {e}")
67
- return images
68
-
69
- def app():
70
- images = load_images_from_folder(photos_folder)
71
- gallery_images = [img[0] for img in images] if images else []
72
-
73
- with gr.Blocks() as demo:
74
- with gr.Row():
75
- with gr.Column():
76
- input_image = gr.Image(label="Input Image", type="pil")
77
- clear_button = gr.Button("Clear")
78
-
79
- with gr.Column():
80
- gallery = gr.Gallery(
81
- label="Image Gallery",
82
- value=gallery_images,
83
- columns=3,
84
- rows=2,
85
- height="auto"
86
- )
87
-
88
- with gr.Column():
89
- output_image = gr.Image(label="Result Image", type="pil")
90
-
91
- # Handle gallery selection
92
- def on_select(evt: gr.SelectData):
93
- if 0 <= evt.index < len(images):
94
- return images[evt.index][0]
95
- return None
96
-
97
- gallery.select(
98
- fn=on_select,
99
- outputs=input_image
100
- )
101
-
102
- # Process image when input changes
103
- input_image.change(
104
- fn=process_image,
105
- inputs=input_image,
106
- outputs=output_image
107
- )
108
-
109
- # Clear button functionality
110
- clear_button.click(
111
- fn=lambda: None,
112
- outputs=input_image
113
- )
114
-
115
- demo.launch()
116
 
117
- if __name__ == "__main__":
118
- app()
 
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
import json
from omegaconf import OmegaConf
import sys
import os
from PIL import Image
import torchvision.transforms as transforms

# Download the pretrained generator weights, its JSON config, and the
# model definition from the Hugging Face Hub.
repo_id = "Kiwinicki/sat2map-generator"
generator_path = hf_hub_download(repo_id=repo_id, filename="generator.pth")
config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
model_path = hf_hub_download(repo_id=repo_id, filename="model.py")

# Make the downloaded model.py importable.
sys.path.append(os.path.dirname(model_path))
from model import Generator

# Load the generator configuration (plain JSON wrapped in an OmegaConf node
# so the Generator can use attribute-style access).
with open(config_path, "r", encoding="utf-8") as f:
    config_dict = json.load(f)
cfg = OmegaConf.create(config_dict)

# Initialize the model on GPU when available, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = Generator(cfg).to(device)
generator.load_state_dict(torch.load(generator_path, map_location=device))
generator.eval()

# Input pipeline: resize to the 256x256 model resolution and normalize each
# channel with mean/std 0.5, mapping pixel values into [-1, 1].
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])


def process_image(image):
    """Translate a satellite photo into a map-style image.

    Args:
        image: input PIL.Image, or None when the Gradio input is cleared.

    Returns:
        A PIL.Image with the generated map, or None when no input is given.
    """
    # Gradio passes None when the input component is empty/cleared;
    # without this guard transform(None) raises.
    if image is None:
        return None

    # Convert to a normalized 1x3x256x256 batch tensor on the model device.
    image_tensor = transform(image).unsqueeze(0).to(device)

    # Inference only — no gradients needed.
    with torch.no_grad():
        output_tensor = generator(image_tensor)

    # Denormalize (inverse of the 0.5/0.5 input normalization) back to
    # [0, 1]; clamp guards against values slightly outside the range
    # before ToPILImage's 8-bit conversion, which would otherwise wrap.
    output_image = output_tensor.squeeze(0).cpu()
    output_image = (output_image * 0.5 + 0.5).clamp(0.0, 1.0)
    output_image = transforms.ToPILImage()(output_image)

    return output_image


iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs="image",
    title="Satellite to Map Generator"
)

iface.launch()
 
requirements.txt CHANGED
@@ -2,5 +2,5 @@ huggingface_hub==0.25.2
2
  torch>=2.0.0
3
  torchvision>=0.15.0
4
  gradio
5
- pillow
6
- pydantic==2.10.6
 
2
  torch>=2.0.0
3
  torchvision>=0.15.0
4
  gradio
5
+ omegaconf
6
+ pillow