xzuyn committed on
Commit
948a4f1
·
verified ·
1 Parent(s): fa25ed4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -116
app.py CHANGED
@@ -1,124 +1,75 @@
1
- import gradio as gr
2
- from transformers import AutoProcessor, AutoModelForCausalLM
3
- import re
4
- from PIL import Image
5
  import os
6
- import numpy as np
7
- import spaces
8
  import subprocess
 
 
 
9
  import torch
 
10
 
11
 
12
- subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
 
 
13
 
14
- model = AutoModelForCausalLM.from_pretrained(
15
- 'PJMixers-Images/Florence-2-base-Castollux-v0.5',
16
- trust_remote_code=True
17
- ).eval()
18
- processor = AutoProcessor.from_pretrained(
19
- 'PJMixers-Images/Florence-2-base-Castollux-v0.5',
20
- trust_remote_code=True
21
- )
22
 
23
  TITLE = "# [PJMixers-Images/Florence-2-base-Castollux-v0.5](https://huggingface.co/PJMixers-Images/Florence-2-base-Castollux-v0.5)"
24
 
25
 
26
- @spaces.GPU
27
  def process_image(image):
28
- if isinstance(image, np.ndarray):
29
- image = Image.fromarray(image)
30
- elif isinstance(image, str):
31
- image = Image.open(image)
32
- if image.mode != "RGB":
33
- image = image.convert("RGB")
34
-
35
- inputs = processor(text="<CAPTION>", images=image, return_tensors="pt")
36
- generated_ids = model.generate(
37
- input_ids=inputs["input_ids"],
38
- pixel_values=inputs["pixel_values"],
39
- max_new_tokens=1024,
40
- num_beams=5,
41
- do_sample=True
42
- )
43
- generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
44
-
45
- return processor.post_process_generation(generated_text, task="<CAPTION>", image_size=(image.width, image.height))
46
-
47
-
48
- def extract_frames(image_path, output_folder):
49
- with Image.open(image_path) as img:
50
- base_name = os.path.splitext(os.path.basename(image_path))[0]
51
- frame_paths = []
52
-
53
- try:
54
- for i in range(0, img.n_frames):
55
- img.seek(i)
56
- frame_path = os.path.join(output_folder, f"{base_name}_frame_{i:03d}.png")
57
- img.save(frame_path)
58
- frame_paths.append(frame_path)
59
- except EOFError:
60
- pass # We've reached the end of the sequence
61
-
62
- return frame_paths
63
-
64
-
65
- def process_folder(folder_path):
66
- if not os.path.isdir(folder_path):
67
- return "Invalid folder path."
68
-
69
- processed_files = []
70
- skipped_files = []
71
- for filename in os.listdir(folder_path):
72
- if filename.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp', '.heic')):
73
- image_path = os.path.join(folder_path, filename)
74
- txt_filename = os.path.splitext(filename)[0] + '.txt'
75
- txt_path = os.path.join(folder_path, txt_filename)
76
-
77
- # Check if the corresponding text file already exists
78
- if os.path.exists(txt_path):
79
- skipped_files.append(f"Skipped {filename} (text file already exists)")
80
- continue
81
-
82
- # Check if the image has multiple frames
83
- with Image.open(image_path) as img:
84
- if getattr(img, "is_animated", False) and img.n_frames > 1:
85
- # Extract frames
86
- frames = extract_frames(image_path, folder_path)
87
- for frame_path in frames:
88
- frame_txt_filename = os.path.splitext(os.path.basename(frame_path))[0] + '.txt'
89
- frame_txt_path = os.path.join(folder_path, frame_txt_filename)
90
-
91
- # Check if the corresponding text file for the frame already exists
92
- if os.path.exists(frame_txt_path):
93
- skipped_files.append(f"Skipped {os.path.basename(frame_path)} (text file already exists)")
94
- continue
95
-
96
- caption = process_image(frame_path)
97
-
98
- with open(frame_txt_path, 'w', encoding='utf-8') as f:
99
- f.write(caption)
100
-
101
- processed_files.append(f"Processed {os.path.basename(frame_path)} -> {frame_txt_filename}")
102
- else:
103
- # Process single image
104
- caption = process_image(image_path)
105
-
106
- with open(txt_path, 'w', encoding='utf-8') as f:
107
- f.write(caption)
108
-
109
- processed_files.append(f"Processed {filename} -> {txt_filename}")
110
-
111
- result = "\n".join(processed_files + skipped_files)
112
-
113
- return result if result else "No image files found or all files were skipped in the specified folder."
114
 
 
 
115
  css = """
116
  #output { height: 500px; overflow: auto; border: 1px solid #ccc; }
117
  """
118
 
119
  with gr.Blocks(css=css) as demo:
120
  gr.Markdown(TITLE)
121
-
122
  with gr.Tab(label="Single Image Processing"):
123
  with gr.Row():
124
  with gr.Column():
@@ -126,7 +77,7 @@ with gr.Blocks(css=css) as demo:
126
  submit_btn = gr.Button(value="Submit")
127
  with gr.Column():
128
  output_text = gr.Textbox(label="Output Text")
129
-
130
  gr.Examples(
131
  [
132
  ["eval_img_1.jpg"],
@@ -136,22 +87,15 @@ with gr.Blocks(css=css) as demo:
136
  ["eval_img_5.jpg"],
137
  ["eval_img_6.jpg"],
138
  ["eval_img_7.png"],
139
- ["eval_img_8.jpg"]
140
  ],
141
  inputs=[input_img],
142
  outputs=[output_text],
143
  fn=process_image,
144
- label='Try captioning on below examples'
145
  )
146
-
147
- submit_btn.click(process_image, [input_img], [output_text])
148
 
149
- with gr.Tab(label="Batch Processing"):
150
- with gr.Row():
151
- folder_input = gr.Textbox(label="Input Folder Path")
152
- batch_submit_btn = gr.Button(value="Process Folder")
153
- batch_output = gr.Textbox(label="Batch Processing Results", lines=10)
154
-
155
- batch_submit_btn.click(process_folder, [folder_input], [batch_output])
156
 
157
- demo.launch(debug=True)
 
 
 
 
 
 
1
  import os
2
+ import re
 
3
  import subprocess
4
+ import numpy as np
5
+ from PIL import Image
6
+ import gradio as gr
7
  import torch
8
+ from transformers import AutoProcessor, AutoModelForCausalLM
9
 
10
 
11
+ # Load the model and processor; trust_remote_code=True lets transformers run the custom code shipped in the model repository
12
+ model_name = "PJMixers-Images/Florence-2-base-Castollux-v0.5"
13
+ model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True).eval()
14
+ processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
15
 
16
+ # Set device (GPU if available)
17
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
+ model.to(device)
 
 
 
 
 
19
 
20
  TITLE = "# [PJMixers-Images/Florence-2-base-Castollux-v0.5](https://huggingface.co/PJMixers-Images/Florence-2-base-Castollux-v0.5)"
21
 
22
 
 
def process_image(image):
    """
    Generate a caption for a single image.

    Args:
        image: The image to caption, given as a file path (str), a numpy
            array, or a PIL Image. ``None`` (no image supplied) is reported
            as an error message instead of being processed.

    Returns:
        The caption produced by the model on success. On any failure a
        string starting with "Error processing image:" is returned; no
        exception is propagated to the caller.
    """
    task = "<CAPTION>"

    # Nothing was supplied (e.g. Submit pressed with an empty image input):
    # say so plainly instead of leaking an AttributeError message.
    if image is None:
        return "Error processing image: no image provided."

    try:
        # Convert input to a PIL image if necessary
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        elif isinstance(image, str):
            # Image.open is lazy and keeps the file handle open; convert()
            # loads the pixels into a new image so the file can be closed.
            with Image.open(image) as opened:
                image = opened.convert("RGB")
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Prepare inputs for the model and move them to the model's device
        inputs = processor(text=task, images=image, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Disable gradients during inference
        with torch.no_grad():
            generated_ids = model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=1024,
                num_beams=5,
                do_sample=True,
            )

        # Decode and post-process the generated text
        generated_text = processor.batch_decode(
            generated_ids, skip_special_tokens=False
        )[0]
        caption = processor.post_process_generation(
            generated_text, task=task, image_size=(image.width, image.height)
        )

        # NOTE(review): the Florence-2 model card shows post_process_generation
        # returning a dict keyed by the task prompt; unwrap it so callers get
        # the caption text. A non-dict result is passed through unchanged.
        if isinstance(caption, dict):
            caption = caption.get(task, caption)
        return caption

    except Exception as e:
        # UI boundary: report the failure as text rather than raising.
        return f"Error processing image: {e}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
+
65
+ # Custom CSS to style the output box
66
  css = """
67
  #output { height: 500px; overflow: auto; border: 1px solid #ccc; }
68
  """
69
 
70
  with gr.Blocks(css=css) as demo:
71
  gr.Markdown(TITLE)
72
+
73
  with gr.Tab(label="Single Image Processing"):
74
  with gr.Row():
75
  with gr.Column():
 
77
  submit_btn = gr.Button(value="Submit")
78
  with gr.Column():
79
  output_text = gr.Textbox(label="Output Text")
80
+
81
  gr.Examples(
82
  [
83
  ["eval_img_1.jpg"],
 
87
  ["eval_img_5.jpg"],
88
  ["eval_img_6.jpg"],
89
  ["eval_img_7.png"],
90
+ ["eval_img_8.jpg"],
91
  ],
92
  inputs=[input_img],
93
  outputs=[output_text],
94
  fn=process_image,
95
+ label="Try captioning on below examples",
96
  )
 
 
97
 
98
+ submit_btn.click(process_image, [input_img], [output_text])
 
 
 
 
 
 
99
 
100
+ if __name__ == "__main__":
101
+ demo.launch(debug=True)