zhiqiulin committed
Commit 3da86ac · verified · 1 Parent(s): 412ada8

Update app.py

Files changed (1):
  1. app.py +103 -41
app.py CHANGED
@@ -2,43 +2,29 @@ import spaces
import gradio as gr
import torch
torch.jit.script = lambda f: f # Avoid script error in lambda
-from t2v_metrics import VQAScore, list_all_vqascore_models
+from t2v_metrics import VQAScore
+from functools import lru_cache

+# Remove any global model loading or CUDA initialization
+# Do not call torch.cuda.is_available() at the global scope

-def update_model(model_name):
+@lru_cache()
+def get_model(model_name):
+    # This function will cache the model per process
    return VQAScore(model=model_name, device="cuda")

-# Use global variables for model pipe and current model name
-global model_pipe, cur_model_name
-cur_model_name = "clip-flant5-xl"
-model_pipe = update_model(cur_model_name)
-
-
-# Ensure GPU context manager is imported correctly (assuming spaces is a module you have)
-#try:
-#from spaces import GPU # i believe this is wrong, spaces package does not have "GPU"
-#except ImportError:
-# GPU = lambda duration: (lambda f: f) # Dummy decorator if spaces.GPU is not available
-
-if torch.cuda.is_available():
-    model_pipe.device = "cuda"
-else:
-    print("CUDA is not available")
-
-@spaces.GPU # a duration lower than 60 does not work, leave as is.
+@spaces.GPU # Decorate the function to use GPU
def generate(model_name, image, text):
-    global model_pipe, cur_model_name
-
-    if model_name != cur_model_name:
-        cur_model_name = model_name # Update the current model name
-        model_pipe = update_model(model_name)
+    # Load the model inside the GPU context
+    model_pipe = get_model(model_name)

-    print("Image:", image) # Debug: Print image path
-    print("Text:", text) # Debug: Print text input
+    print("Image:", image)
+    print("Text:", text)
    print("Using model:", model_name)

    try:
-        result = model_pipe(images=[image], texts=[text]).cpu()[0][0].item() # Perform the model inference
+        # Perform the model inference
+        result = model_pipe(images=[image], texts=[text]).cpu()[0][0].item()
        print("Result:", result)
    except RuntimeError as e:
        print(f"RuntimeError during model inference: {e}")
@@ -46,30 +32,106 @@ def generate(model_name, image, text):

    return result

-
+@spaces.GPU # Decorate the function to use GPU
def rank_images(model_name, images, text):
-    global model_pipe, cur_model_name
-
-    if model_name != cur_model_name:
-        cur_model_name = model_name # Update the current model name
-        model_pipe = update_model(model_name)
+    # Load the model inside the GPU context
+    model_pipe = get_model(model_name)

    images = [image_tuple[0] for image_tuple in images]
-    print("Images:", images) # Debug: Print image paths
-    print("Text:", text) # Debug: Print text input
+    print("Images:", images)
+    print("Text:", text)
    print("Using model:", model_name)

    try:
-        results = model_pipe(images=images, texts=[text]).cpu()[:, 0].tolist() # Perform the model inference on all images
-        print("Initial results: should be imgs x texts", results)
-        ranked_results = sorted(zip(images, results), key=lambda x: x[1], reverse=True) # Rank results
-        ranked_images = [(img, f"Rank: {rank + 1} - Score: {score:.2f}") for rank, (img, score) in enumerate(ranked_results)] # Pair images with their scores and rank
+        # Perform the model inference on all images
+        results = model_pipe(images=images, texts=[text]).cpu()[:, 0].tolist()
+        print("Initial results:", results)
+        # Rank results
+        ranked_results = sorted(zip(images, results), key=lambda x: x[1], reverse=True)
+        # Pair images with their scores and rank
+        ranked_images = [
+            (img, f"Rank: {rank + 1} - Score: {score:.2f}")
+            for rank, (img, score) in enumerate(ranked_results)
+        ]
        print("Ranked Results:", ranked_results)
    except RuntimeError as e:
        print(f"RuntimeError during model inference: {e}")
        raise e

    return ranked_images
+
+# import spaces
+# import gradio as gr
+# import torch
+# torch.jit.script = lambda f: f # Avoid script error in lambda
+# from t2v_metrics import VQAScore, list_all_vqascore_models
+
+
+# def update_model(model_name):
+#     return VQAScore(model=model_name, device="cuda")
+
+# # Use global variables for model pipe and current model name
+# global model_pipe, cur_model_name
+# cur_model_name = "clip-flant5-xl"
+# model_pipe = update_model(cur_model_name)
+
+
+# # Ensure GPU context manager is imported correctly (assuming spaces is a module you have)
+# #try:
+# #from spaces import GPU # i believe this is wrong, spaces package does not have "GPU"
+# #except ImportError:
+# # GPU = lambda duration: (lambda f: f) # Dummy decorator if spaces.GPU is not available
+
+# if torch.cuda.is_available():
+#     model_pipe.device = "cuda"
+# else:
+#     print("CUDA is not available")
+
+# @spaces.GPU # a duration lower than 60 does not work, leave as is.
+# def generate(model_name, image, text):
+#     global model_pipe, cur_model_name
+
+#     if model_name != cur_model_name:
+#         cur_model_name = model_name # Update the current model name
+#         model_pipe = update_model(model_name)
+
+#     print("Image:", image) # Debug: Print image path
+#     print("Text:", text) # Debug: Print text input
+#     print("Using model:", model_name)
+
+#     try:
+#         result = model_pipe(images=[image], texts=[text]).cpu()[0][0].item() # Perform the model inference
+#         print("Result:", result)
+#     except RuntimeError as e:
+#         print(f"RuntimeError during model inference: {e}")
+#         raise e
+
+#     return result
+
+
+# def rank_images(model_name, images, text):
+#     global model_pipe, cur_model_name
+
+#     if model_name != cur_model_name:
+#         cur_model_name = model_name # Update the current model name
+#         model_pipe = update_model(model_name)
+
+#     images = [image_tuple[0] for image_tuple in images]
+#     print("Images:", images) # Debug: Print image paths
+#     print("Text:", text) # Debug: Print text input
+#     print("Using model:", model_name)
+
+#     try:
+#         results = model_pipe(images=images, texts=[text]).cpu()[:, 0].tolist() # Perform the model inference on all images
+#         print("Initial results: should be imgs x texts", results)
+#         ranked_results = sorted(zip(images, results), key=lambda x: x[1], reverse=True) # Rank results
+#         ranked_images = [(img, f"Rank: {rank + 1} - Score: {score:.2f}") for rank, (img, score) in enumerate(ranked_results)] # Pair images with their scores and rank
+#         print("Ranked Results:", ranked_results)
+#     except RuntimeError as e:
+#         print(f"RuntimeError during model inference: {e}")
+#         raise e
+
+#     return ranked_images


### EXAMPLES ###
@@ -190,4 +252,4 @@ with gr.Blocks() as demo_vqascore_ranking:

# Launch the interface
demo_vqascore_ranking.queue()
-demo_vqascore_ranking.launch(share=False)
+demo_vqascore_ranking.launch(share=True)
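
The refactor above boils down to one pattern: never touch CUDA at import time, construct the model lazily inside the @spaces.GPU-decorated handlers, and memoize it with functools.lru_cache so repeated requests with the same model name reuse the already-loaded pipeline. Below is a minimal sketch of that pattern using the names from this commit (get_model, VQAScore, "clip-flant5-xl"); the score_one helper, the example image path, and the prompt are hypothetical stand-ins for the app's generate handler, and the sketch assumes the spaces and t2v_metrics packages plus a CUDA device are available.

import spaces                      # Hugging Face Spaces GPU decorator, as used in the commit
from functools import lru_cache

from t2v_metrics import VQAScore   # VQAScore pipeline from t2v_metrics


@lru_cache()                       # keep one pipeline per model name, per process
def get_model(model_name):
    # Built only on the first request for this model_name; later calls reuse the instance.
    return VQAScore(model=model_name, device="cuda")


@spaces.GPU                        # GPU is only needed while this handler runs
def score_one(model_name, image_path, text):
    # Hypothetical single-pair helper mirroring the app's generate() handler:
    # fetch the (cached) model inside the GPU context, then score one image/text pair.
    model_pipe = get_model(model_name)
    return model_pipe(images=[image_path], texts=[text]).cpu()[0][0].item()


if __name__ == "__main__":
    # Illustrative call only; "example.jpg" and the prompt text are placeholders.
    print(score_one("clip-flant5-xl", "example.jpg", "a photo of a dog"))

One side effect to be aware of: lru_cache keeps every model that has been requested resident for the life of the process, which is presumably acceptable for a demo with a short model list; lru_cache(maxsize=1) would instead evict the previous model whenever a different name is requested.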