Siwon123 commited on
Commit
0b9d903
·
1 Parent(s): c13b290
Files changed (4) hide show
  1. inference.py +4 -4
  2. requirements_colab.txt +520 -0
  3. text_net/DGRN.py +1 -1
  4. text_net/moco.py +2 -2
inference.py CHANGED
@@ -22,7 +22,7 @@ def test_Derain_Dehaze(opt, net, dataset, task="derain"):
22
 
23
  with torch.no_grad():
24
  for ([degraded_name], degradation, degrad_patch, clean_patch, text_prompt) in tqdm(testloader):
25
- degrad_patch, clean_patch = degrad_patch.cuda(), clean_patch.cuda()
26
  restored = net(x_query=degrad_patch, x_key=degrad_patch, text_prompt = text_prompt)
27
 
28
  return save_image_tensor(restored)
@@ -39,10 +39,10 @@ def infer(text_prompt = "", img=None):
39
 
40
  opt = parser.parse_args()
41
  # opt.text_prompt = text_prompt
 
42
 
43
  np.random.seed(0)
44
  torch.manual_seed(0)
45
- torch.cuda.set_device(opt.cuda)
46
 
47
  opt.batch_size = 7
48
  ckpt_path = opt.ckpt_path
@@ -50,9 +50,9 @@ def infer(text_prompt = "", img=None):
50
  derain_set = DerainDehazeDataset(opt, img=img, text_prompt = text_prompt)
51
 
52
  # Make network
53
- net = AirNet(opt).cuda()
54
  net.eval()
55
- net.load_state_dict(torch.load(ckpt_path, map_location=torch.device(opt.cuda)))
56
 
57
  restored = test_Derain_Dehaze(opt, net, derain_set, task="derain")
58
 
 
22
 
23
  with torch.no_grad():
24
  for ([degraded_name], degradation, degrad_patch, clean_patch, text_prompt) in tqdm(testloader):
25
+ degrad_patch, clean_patch = degrad_patch.to(device), clean_patch.to(device)
26
  restored = net(x_query=degrad_patch, x_key=degrad_patch, text_prompt = text_prompt)
27
 
28
  return save_image_tensor(restored)
 
39
 
40
  opt = parser.parse_args()
41
  # opt.text_prompt = text_prompt
42
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
43
 
44
  np.random.seed(0)
45
  torch.manual_seed(0)
 
46
 
47
  opt.batch_size = 7
48
  ckpt_path = opt.ckpt_path
 
50
  derain_set = DerainDehazeDataset(opt, img=img, text_prompt = text_prompt)
51
 
52
  # Make network
53
+ net = AirNet(opt).to(device)
54
  net.eval()
55
+ net.load_state_dict(torch.load(ckpt_path, map_location=device))
56
 
57
  restored = test_Derain_Dehaze(opt, net, derain_set, task="derain")
58
 
requirements_colab.txt ADDED
@@ -0,0 +1,520 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==1.4.0
2
+ accelerate==0.34.2
3
+ addict==2.4.0
4
+ aiohappyeyeballs==2.4.3
5
+ aiohttp==3.10.10
6
+ aiosignal==1.3.1
7
+ alabaster==0.7.16
8
+ albucore==0.0.16
9
+ albumentations==1.4.15
10
+ altair==4.2.2
11
+ annotated-types==0.7.0
12
+ anyio==3.7.1
13
+ argon2-cffi==23.1.0
14
+ argon2-cffi-bindings==21.2.0
15
+ array_record==0.5.1
16
+ arviz==0.19.0
17
+ astropy==6.1.4
18
+ astropy-iers-data==0.2024.10.7.0.32.46
19
+ astunparse==1.6.3
20
+ async-timeout==4.0.3
21
+ atpublic==4.1.0
22
+ attrs==24.2.0
23
+ audioread==3.0.1
24
+ autograd==1.7.0
25
+ babel==2.16.0
26
+ backcall==0.2.0
27
+ beautifulsoup4==4.12.3
28
+ bigframes==1.22.0
29
+ bigquery-magics==0.4.0
30
+ bleach==6.1.0
31
+ blinker==1.4
32
+ blis==0.7.11
33
+ blosc2==2.0.0
34
+ bokeh==3.4.3
35
+ Bottleneck==1.4.0
36
+ bqplot==0.12.43
37
+ branca==0.8.0
38
+ build==1.2.2.post1
39
+ CacheControl==0.14.0
40
+ cachetools==5.5.0
41
+ catalogue==2.0.10
42
+ certifi==2024.8.30
43
+ cffi==1.17.1
44
+ chardet==5.2.0
45
+ charset-normalizer==3.4.0
46
+ chex==0.1.87
47
+ clarabel==0.9.0
48
+ click==8.1.7
49
+ clip @ git+https://github.com/openai/CLIP.git@dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
50
+ cloudpathlib==0.19.0
51
+ cloudpickle==2.2.1
52
+ cmake==3.30.4
53
+ cmdstanpy==1.2.4
54
+ colab-ssh==0.3.27
55
+ colorcet==3.1.0
56
+ colorlover==0.3.0
57
+ colour==0.1.5
58
+ community==1.0.0b1
59
+ confection==0.1.5
60
+ cons==0.4.6
61
+ contextlib2==21.6.0
62
+ contourpy==1.3.0
63
+ cryptography==43.0.1
64
+ cuda-python==12.2.1
65
+ cudf-cu12 @ https://pypi.nvidia.com/cudf-cu12/cudf_cu12-24.6.1-cp310-cp310-manylinux_2_28_x86_64.whl
66
+ cufflinks==0.17.3
67
+ cupy-cuda12x==12.2.0
68
+ cvxopt==1.3.2
69
+ cvxpy==1.5.3
70
+ cycler==0.12.1
71
+ cymem==2.0.8
72
+ Cython==3.0.11
73
+ dask==2024.8.0
74
+ datascience==0.17.6
75
+ db-dtypes==1.3.0
76
+ dbus-python==1.2.18
77
+ debugpy==1.6.6
78
+ decorator==4.4.2
79
+ defusedxml==0.7.1
80
+ Deprecated==1.2.14
81
+ distributed==2024.8.0
82
+ distro==1.7.0
83
+ dlib==19.24.2
84
+ dm-tree==0.1.8
85
+ docstring_parser==0.16
86
+ docutils==0.18.1
87
+ dopamine_rl==4.0.9
88
+ duckdb==1.1.1
89
+ earthengine-api==1.0.0
90
+ easydict==1.13
91
+ ecos==2.0.14
92
+ editdistance==0.8.1
93
+ eerepr==0.0.4
94
+ einops==0.8.0
95
+ en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl#sha256=86cc141f63942d4b2c5fcee06630fd6f904788d2f0ab005cce45aadb8fb73889
96
+ entrypoints==0.4
97
+ et-xmlfile==1.1.0
98
+ etils==1.9.4
99
+ etuples==0.3.9
100
+ eval_type_backport==0.2.0
101
+ exceptiongroup==1.2.2
102
+ fastai==2.7.17
103
+ fastcore==1.7.13
104
+ fastdownload==0.0.7
105
+ fastjsonschema==2.20.0
106
+ fastprogress==1.0.3
107
+ fastrlock==0.8.2
108
+ filelock==3.16.1
109
+ firebase-admin==6.5.0
110
+ Flask==2.2.5
111
+ flatbuffers==24.3.25
112
+ flax==0.8.5
113
+ folium==0.17.0
114
+ fonttools==4.54.1
115
+ frozendict==2.4.5
116
+ frozenlist==1.4.1
117
+ fsspec==2024.6.1
118
+ ftfy==6.3.0
119
+ future==1.0.0
120
+ gast==0.6.0
121
+ gcsfs==2024.6.1
122
+ GDAL==3.6.4
123
+ gdown==5.2.0
124
+ geemap==0.34.5
125
+ gensim==4.3.3
126
+ geocoder==1.38.1
127
+ geographiclib==2.0
128
+ geopandas==1.0.1
129
+ geopy==2.4.1
130
+ gin-config==0.5.0
131
+ glob2==0.7
132
+ google==2.0.3
133
+ google-ai-generativelanguage==0.6.6
134
+ google-api-core==2.19.2
135
+ google-api-python-client==2.137.0
136
+ google-auth==2.27.0
137
+ google-auth-httplib2==0.2.0
138
+ google-auth-oauthlib==1.2.1
139
+ google-cloud-aiplatform==1.70.0
140
+ google-cloud-bigquery==3.25.0
141
+ google-cloud-bigquery-connection==1.15.5
142
+ google-cloud-bigquery-storage==2.26.0
143
+ google-cloud-bigtable==2.26.0
144
+ google-cloud-core==2.4.1
145
+ google-cloud-datastore==2.19.0
146
+ google-cloud-firestore==2.16.1
147
+ google-cloud-functions==1.16.5
148
+ google-cloud-iam==2.15.2
149
+ google-cloud-language==2.13.4
150
+ google-cloud-pubsub==2.25.0
151
+ google-cloud-resource-manager==1.12.5
152
+ google-cloud-storage==2.8.0
153
+ google-cloud-translate==3.15.5
154
+ google-colab @ file:///colabtools/dist/google_colab-1.0.0.tar.gz
155
+ google-crc32c==1.6.0
156
+ google-generativeai==0.7.2
157
+ google-pasta==0.2.0
158
+ google-resumable-media==2.7.2
159
+ googleapis-common-protos==1.65.0
160
+ googledrivedownloader==0.4
161
+ graphviz==0.20.3
162
+ greenlet==3.1.1
163
+ grpc-google-iam-v1==0.13.1
164
+ grpcio==1.64.1
165
+ grpcio-status==1.48.2
166
+ gspread==6.0.2
167
+ gspread-dataframe==3.3.1
168
+ gym==0.25.2
169
+ gym-notices==0.0.8
170
+ h5netcdf==1.4.0
171
+ h5py==3.11.0
172
+ holidays==0.58
173
+ holoviews==1.19.1
174
+ html5lib==1.1
175
+ httpimport==1.4.0
176
+ httplib2==0.22.0
177
+ huggingface-hub==0.24.7
178
+ humanize==4.10.0
179
+ hyperopt==0.2.7
180
+ ibis-framework==9.2.0
181
+ idna==3.10
182
+ imageio==2.35.1
183
+ imageio-ffmpeg==0.5.1
184
+ imagesize==1.4.1
185
+ imbalanced-learn==0.12.4
186
+ imgaug==0.4.0
187
+ immutabledict==4.2.0
188
+ importlib_metadata==8.5.0
189
+ importlib_resources==6.4.5
190
+ imutils==0.5.4
191
+ inflect==7.4.0
192
+ iniconfig==2.0.0
193
+ intel-cmplr-lib-ur==2024.2.1
194
+ intel-openmp==2024.2.1
195
+ ipyevents==2.0.2
196
+ ipyfilechooser==0.6.0
197
+ ipykernel==5.5.6
198
+ ipyleaflet==0.19.2
199
+ ipyparallel==8.8.0
200
+ ipython==7.34.0
201
+ ipython-genutils==0.2.0
202
+ ipython-sql==0.5.0
203
+ ipytree==0.2.2
204
+ ipywidgets==7.7.1
205
+ itsdangerous==2.2.0
206
+ jax==0.4.33
207
+ jax-cuda12-pjrt==0.4.33
208
+ jax-cuda12-plugin==0.4.33
209
+ jaxlib==0.4.33
210
+ jeepney==0.7.1
211
+ jellyfish==1.1.0
212
+ jieba==0.42.1
213
+ Jinja2==3.1.4
214
+ joblib==1.4.2
215
+ jsonpickle==3.3.0
216
+ jsonschema==4.23.0
217
+ jsonschema-specifications==2024.10.1
218
+ jupyter-client==6.1.12
219
+ jupyter-console==6.1.0
220
+ jupyter-leaflet==0.19.2
221
+ jupyter-server==1.24.0
222
+ jupyter_core==5.7.2
223
+ jupyterlab_pygments==0.3.0
224
+ jupyterlab_widgets==3.0.13
225
+ kaggle==1.6.17
226
+ kagglehub==0.3.1
227
+ keras==3.4.1
228
+ keyring==23.5.0
229
+ kiwisolver==1.4.7
230
+ langcodes==3.4.1
231
+ language_data==1.2.0
232
+ launchpadlib==1.10.16
233
+ lazr.restfulclient==0.14.4
234
+ lazr.uri==1.0.6
235
+ lazy_loader==0.4
236
+ libclang==18.1.1
237
+ librosa==0.10.2.post1
238
+ lightgbm==4.5.0
239
+ linkify-it-py==2.0.3
240
+ llvmlite==0.43.0
241
+ locket==1.0.0
242
+ logical-unification==0.4.6
243
+ lxml==4.9.4
244
+ marisa-trie==1.2.0
245
+ Markdown==3.7
246
+ markdown-it-py==3.0.0
247
+ MarkupSafe==3.0.1
248
+ matplotlib==3.7.1
249
+ matplotlib-inline==0.1.7
250
+ matplotlib-venn==1.1.1
251
+ mdit-py-plugins==0.4.2
252
+ mdurl==0.1.2
253
+ miniKanren==1.0.3
254
+ missingno==0.5.2
255
+ mistune==0.8.4
256
+ mizani==0.11.4
257
+ mkl==2024.2.2
258
+ ml-dtypes==0.4.1
259
+ mlxtend==0.23.1
260
+ mmcv==2.2.0
261
+ mmengine==0.10.5
262
+ more-itertools==10.5.0
263
+ moviepy==1.0.3
264
+ mpmath==1.3.0
265
+ msgpack==1.0.8
266
+ multidict==6.1.0
267
+ multipledispatch==1.0.0
268
+ multitasking==0.0.11
269
+ murmurhash==1.0.10
270
+ music21==9.1.0
271
+ namex==0.0.8
272
+ natsort==8.4.0
273
+ nbclassic==1.1.0
274
+ nbclient==0.10.0
275
+ nbconvert==6.5.4
276
+ nbformat==5.10.4
277
+ nest-asyncio==1.6.0
278
+ networkx==3.4
279
+ nibabel==5.2.1
280
+ nltk==3.8.1
281
+ notebook==6.5.5
282
+ notebook_shim==0.2.4
283
+ numba==0.60.0
284
+ numexpr==2.10.1
285
+ numpy==1.26.4
286
+ nvidia-cublas-cu12==12.6.3.3
287
+ nvidia-cuda-cupti-cu12==12.6.80
288
+ nvidia-cuda-nvcc-cu12==12.6.77
289
+ nvidia-cuda-runtime-cu12==12.6.77
290
+ nvidia-cudnn-cu12==9.5.0.50
291
+ nvidia-cufft-cu12==11.3.0.4
292
+ nvidia-cusolver-cu12==11.7.1.2
293
+ nvidia-cusparse-cu12==12.5.4.2
294
+ nvidia-nccl-cu12==2.23.4
295
+ nvidia-nvjitlink-cu12==12.6.77
296
+ nvtx==0.2.10
297
+ oauth2client==4.1.3
298
+ oauthlib==3.2.2
299
+ opencv-contrib-python==4.10.0.84
300
+ opencv-python==4.10.0.84
301
+ opencv-python-headless==4.10.0.84
302
+ openpyxl==3.1.5
303
+ opentelemetry-api==1.16.0
304
+ opentelemetry-sdk==1.16.0
305
+ opentelemetry-semantic-conventions==0.37b0
306
+ opt_einsum==3.4.0
307
+ optax==0.2.3
308
+ optree==0.13.0
309
+ orbax-checkpoint==0.6.4
310
+ ordered-set==4.1.0
311
+ osqp==0.6.7.post0
312
+ packaging==24.1
313
+ pandas==2.2.2
314
+ pandas-datareader==0.10.0
315
+ pandas-gbq==0.23.2
316
+ pandas-stubs==2.2.2.240909
317
+ pandocfilters==1.5.1
318
+ panel==1.4.5
319
+ param==2.1.1
320
+ parso==0.8.4
321
+ parsy==2.1
322
+ partd==1.4.2
323
+ pathlib==1.0.1
324
+ patsy==0.5.6
325
+ peewee==3.17.6
326
+ pexpect==4.9.0
327
+ pickleshare==0.7.5
328
+ pillow==10.4.0
329
+ pip-tools==7.4.1
330
+ platformdirs==4.3.6
331
+ plotly==5.24.1
332
+ plotnine==0.13.6
333
+ pluggy==1.5.0
334
+ polars==1.7.1
335
+ pooch==1.8.2
336
+ portpicker==1.5.2
337
+ prefetch_generator==1.0.3
338
+ preshed==3.0.9
339
+ prettytable==3.11.0
340
+ proglog==0.1.10
341
+ progressbar2==4.5.0
342
+ prometheus_client==0.21.0
343
+ promise==2.3
344
+ prompt_toolkit==3.0.48
345
+ propcache==0.2.0
346
+ prophet==1.1.6
347
+ proto-plus==1.24.0
348
+ protobuf==3.20.3
349
+ psutil==5.9.5
350
+ psycopg2==2.9.9
351
+ ptyprocess==0.7.0
352
+ py-cpuinfo==9.0.0
353
+ py4j==0.10.9.7
354
+ pyarrow==16.1.0
355
+ pyarrow-hotfix==0.6
356
+ pyasn1==0.6.1
357
+ pyasn1_modules==0.4.1
358
+ pycocotools==2.0.8
359
+ pycparser==2.22
360
+ pydantic==2.9.2
361
+ pydantic_core==2.23.4
362
+ pydata-google-auth==1.8.2
363
+ pydot==3.0.2
364
+ pydot-ng==2.0.0
365
+ pydotplus==2.0.2
366
+ PyDrive==1.3.1
367
+ PyDrive2==1.20.0
368
+ pyerfa==2.0.1.4
369
+ pygame==2.6.1
370
+ Pygments==2.18.0
371
+ PyGObject==3.42.1
372
+ PyJWT==2.9.0
373
+ pymc==5.16.2
374
+ pymystem3==0.2.0
375
+ pynvjitlink-cu12==0.3.0
376
+ pyogrio==0.10.0
377
+ PyOpenGL==3.1.7
378
+ pyOpenSSL==24.2.1
379
+ pyparsing==3.1.4
380
+ pyperclip==1.9.0
381
+ pyproj==3.7.0
382
+ pyproject_hooks==1.2.0
383
+ pyshp==2.3.1
384
+ PySocks==1.7.1
385
+ pytensor==2.25.5
386
+ pytest==7.4.4
387
+ python-apt==0.0.0
388
+ python-box==7.2.0
389
+ python-dateutil==2.8.2
390
+ python-louvain==0.16
391
+ python-slugify==8.0.4
392
+ python-utils==3.9.0
393
+ pytz==2024.2
394
+ pyviz_comms==3.0.3
395
+ PyYAML==6.0.2
396
+ pyzmq==24.0.1
397
+ qdldl==0.1.7.post4
398
+ ratelim==0.1.6
399
+ referencing==0.35.1
400
+ regex==2024.9.11
401
+ requests==2.32.3
402
+ requests-oauthlib==1.3.1
403
+ requirements-parser==0.9.0
404
+ rich==13.9.2
405
+ rmm-cu12==24.6.0
406
+ rpds-py==0.20.0
407
+ rpy2==3.4.2
408
+ rsa==4.9
409
+ safetensors==0.4.5
410
+ scikit-image==0.24.0
411
+ scikit-learn==1.5.2
412
+ scipy==1.13.1
413
+ scooby==0.10.0
414
+ scs==3.2.7
415
+ seaborn==0.13.2
416
+ SecretStorage==3.3.1
417
+ Send2Trash==1.8.3
418
+ sentencepiece==0.2.0
419
+ shapely==2.0.6
420
+ shellingham==1.5.4
421
+ simple-parsing==0.1.6
422
+ six==1.16.0
423
+ sklearn-pandas==2.2.0
424
+ smart-open==7.0.5
425
+ sniffio==1.3.1
426
+ snowballstemmer==2.2.0
427
+ sortedcontainers==2.4.0
428
+ soundfile==0.12.1
429
+ soupsieve==2.6
430
+ soxr==0.5.0.post1
431
+ spacy==3.7.5
432
+ spacy-legacy==3.0.12
433
+ spacy-loggers==1.0.5
434
+ Sphinx==5.0.2
435
+ sphinxcontrib-applehelp==2.0.0
436
+ sphinxcontrib-devhelp==2.0.0
437
+ sphinxcontrib-htmlhelp==2.1.0
438
+ sphinxcontrib-jsmath==1.0.1
439
+ sphinxcontrib-qthelp==2.0.0
440
+ sphinxcontrib-serializinghtml==2.0.0
441
+ SQLAlchemy==2.0.35
442
+ sqlglot==25.1.0
443
+ sqlparse==0.5.1
444
+ srsly==2.4.8
445
+ ssh-import-id==5.11
446
+ stanio==0.5.1
447
+ statsmodels==0.14.4
448
+ StrEnum==0.4.15
449
+ sympy==1.13.3
450
+ tables==3.8.0
451
+ tabulate==0.9.0
452
+ tbb==2021.13.1
453
+ tblib==3.0.0
454
+ tenacity==9.0.0
455
+ tensorboard==2.17.0
456
+ tensorboard-data-server==0.7.2
457
+ tensorflow==2.17.0
458
+ tensorflow-datasets==4.9.6
459
+ tensorflow-hub==0.16.1
460
+ tensorflow-io-gcs-filesystem==0.37.1
461
+ tensorflow-metadata==1.16.1
462
+ tensorflow-probability==0.24.0
463
+ tensorstore==0.1.66
464
+ termcolor==2.5.0
465
+ terminado==0.18.1
466
+ text-unidecode==1.3
467
+ textblob==0.17.1
468
+ tf-slim==1.1.0
469
+ tf_keras==2.17.0
470
+ thinc==8.2.5
471
+ threadpoolctl==3.5.0
472
+ tifffile==2024.9.20
473
+ tinycss2==1.3.0
474
+ tokenizers==0.19.1
475
+ toml==0.10.2
476
+ tomli==2.0.2
477
+ toolz==0.12.1
478
+ torch @ https://download.pytorch.org/whl/cu121_full/torch-2.4.1%2Bcu121-cp310-cp310-linux_x86_64.whl
479
+ torchaudio @ https://download.pytorch.org/whl/cu121_full/torchaudio-2.4.1%2Bcu121-cp310-cp310-linux_x86_64.whl
480
+ torchsummary==1.5.1
481
+ torchvision @ https://download.pytorch.org/whl/cu121_full/torchvision-0.19.1%2Bcu121-cp310-cp310-linux_x86_64.whl
482
+ tornado==6.3.3
483
+ tqdm==4.66.5
484
+ traitlets==5.7.1
485
+ traittypes==0.2.1
486
+ transformers==4.44.2
487
+ tweepy==4.14.0
488
+ typeguard==4.3.0
489
+ typer==0.12.5
490
+ types-pytz==2024.2.0.20241003
491
+ types-setuptools==75.1.0.20241014
492
+ typing_extensions==4.12.2
493
+ tzdata==2024.2
494
+ tzlocal==5.2
495
+ uc-micro-py==1.0.3
496
+ uritemplate==4.1.1
497
+ urllib3==2.2.3
498
+ vega-datasets==0.9.0
499
+ wadllib==1.3.6
500
+ wasabi==1.1.3
501
+ wcwidth==0.2.13
502
+ weasel==0.4.1
503
+ webcolors==24.8.0
504
+ webencodings==0.5.1
505
+ websocket-client==1.8.0
506
+ Werkzeug==3.0.4
507
+ widgetsnbextension==3.6.9
508
+ wordcloud==1.9.3
509
+ wrapt==1.16.0
510
+ xarray==2024.9.0
511
+ xarray-einstats==0.8.0
512
+ xgboost==2.1.1
513
+ xlrd==2.0.1
514
+ xyzservices==2024.9.0
515
+ yapf==0.40.2
516
+ yarl==1.14.0
517
+ yellowbrick==1.5
518
+ yfinance==0.2.44
519
+ zict==3.0.0
520
+ zipp==3.20.2
text_net/DGRN.py CHANGED
@@ -3,7 +3,7 @@ import torch
3
  from .deform_conv import DCN_layer
4
  import clip
5
 
6
- clip_model, preprocess = clip.load("ViT-B/32", device='cuda')
7
 
8
  # 동적으로 텍스트 임베딩 차원 가져오기
9
  text_embed_dim = clip_model.text_projection.shape[1]
 
3
  from .deform_conv import DCN_layer
4
  import clip
5
 
6
+ clip_model, preprocess = clip.load("ViT-B/32", device='cpu')
7
 
8
  # 동적으로 텍스트 임베딩 차원 가져오기
9
  text_embed_dim = clip_model.text_projection.shape[1]
text_net/moco.py CHANGED
@@ -73,7 +73,7 @@ class MoCo(nn.Module):
73
  num_gpus = batch_size_all // batch_size_this
74
 
75
  # random shuffle index
76
- idx_shuffle = torch.randperm(batch_size_all).cuda()
77
 
78
  # broadcast to all gpus
79
  torch.distributed.broadcast(idx_shuffle, src=0)
@@ -140,7 +140,7 @@ class MoCo(nn.Module):
140
  logits /= self.T
141
 
142
  # labels: positive key indicators
143
- labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()
144
  # dequeue and enqueue
145
  self._dequeue_and_enqueue(k)
146
 
 
73
  num_gpus = batch_size_all // batch_size_this
74
 
75
  # random shuffle index
76
+ idx_shuffle = torch.randperm(batch_size_all).to('cpu')
77
 
78
  # broadcast to all gpus
79
  torch.distributed.broadcast(idx_shuffle, src=0)
 
140
  logits /= self.T
141
 
142
  # labels: positive key indicators
143
+ labels = torch.zeros(logits.shape[0], dtype=torch.long).to('cpu')
144
  # dequeue and enqueue
145
  self._dequeue_and_enqueue(k)
146