sunshangquan
commited on
Commit
·
5a7c1cb
1
Parent(s):
3274b48
basicsr/models/archs/__pycache__/histoformer_arch.cpython-38.pyc
CHANGED
Binary files a/basicsr/models/archs/__pycache__/histoformer_arch.cpython-38.pyc and b/basicsr/models/archs/__pycache__/histoformer_arch.cpython-38.pyc differ
|
|
basicsr/models/archs/histoformer_arch.py
CHANGED
@@ -5,7 +5,7 @@ from pdb import set_trace as stx
|
|
5 |
import numbers
|
6 |
|
7 |
from einops import rearrange
|
8 |
-
|
9 |
#########################################################################
|
10 |
|
11 |
Conv2d = nn.Conv2d
|
@@ -267,7 +267,7 @@ class Upsample(nn.Module):
|
|
267 |
return self.body(x)
|
268 |
|
269 |
##########################################################################
|
270 |
-
class Histoformer(nn.Module):
|
271 |
def __init__(self,
|
272 |
inp_channels=3,
|
273 |
out_channels=3,
|
|
|
5 |
import numbers
|
6 |
|
7 |
from einops import rearrange
|
8 |
+
from huggingface_hub import PyTorchModelHubMixin
|
9 |
#########################################################################
|
10 |
|
11 |
Conv2d = nn.Conv2d
|
|
|
267 |
return self.body(x)
|
268 |
|
269 |
##########################################################################
|
270 |
+
class Histoformer(nn.Module, PyTorchModelHubMixin):
|
271 |
def __init__(self,
|
272 |
inp_channels=3,
|
273 |
out_channels=3,
|
demo.py
CHANGED
@@ -1,12 +1,14 @@
|
|
|
|
1 |
from basicsr.models.archs.histoformer_arch import Histoformer
|
2 |
-
|
3 |
from huggingface_hub import hf_hub_download
|
4 |
# define model
|
5 |
model = Histoformer()
|
6 |
|
7 |
# equip with weights
|
8 |
filepath = hf_hub_download(repo_id="sunsean/Histoformer", filename="Allweather/pretrained_models/net_g_real.pth")
|
9 |
-
|
|
|
10 |
|
11 |
# push to hub
|
12 |
model.push_to_hub("sunsean/Histoformer-real")
|
|
|
1 |
+
import torch
|
2 |
from basicsr.models.archs.histoformer_arch import Histoformer
|
3 |
+
from huggingface_hub import PyTorchModelHubMixin
|
4 |
from huggingface_hub import hf_hub_download
|
5 |
# define model
|
6 |
model = Histoformer()
|
7 |
|
8 |
# equip with weights
|
9 |
filepath = hf_hub_download(repo_id="sunsean/Histoformer", filename="Allweather/pretrained_models/net_g_real.pth")
|
10 |
+
print(filepath)
|
11 |
+
model.load_state_dict(torch.load(filepath, map_location='cpu')['params'], )
|
12 |
|
13 |
# push to hub
|
14 |
model.push_to_hub("sunsean/Histoformer-real")
|