Ashrafb commited on
Commit
e72c238
·
verified ·
1 Parent(s): 05761be

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -14
app.py CHANGED
@@ -27,24 +27,19 @@ from PIL import Image
27
  from io import BytesIO
28
  import torchvision.transforms as T
29
  import requests
30
- from pathlib import Path
31
 
32
- MODEL_URL = "https://www.dropbox.com/s/04suaimdpru76h3/ArtLine_920.pkl?dl=1"
33
- MODEL_PATH = "ArtLine_920.pkl"
34
 
35
- # Download the model file
36
- response = requests.get(MODEL_URL)
37
- if response.status_code == 200:
38
- with open(MODEL_PATH, 'wb') as f:
39
- f.write(response.content)
40
- else:
41
- print("Failed to download the model")
42
 
43
- # Load the model using FastAI
44
- path = Path(".")
45
- learn = load_learner(path, MODEL_PATH)
46
 
47
- app = FastAPI()
48
 
49
  class FeatureLoss(nn.Module):
50
  def __init__(self, m_feat, layer_ids, layer_wgts):
 
27
  from io import BytesIO
28
  import torchvision.transforms as T
29
  import requests
30
+ import model_loader
31
 
32
+ app = FastAPI()
 
33
 
34
+ # Download the model if not already downloaded
35
+ MODEL_URL = "https://www.dropbox.com/s/04suaimdpru76h3/ArtLine_920.pkl?dl=1"
36
+ MODEL_FILENAME = "ArtLine_920.pkl"
37
+ if not os.path.exists(MODEL_FILENAME):
38
+ model_loader.download_model(MODEL_URL, MODEL_FILENAME)
 
 
39
 
40
+ # Load the model
41
+ learn = model_loader.load_model(MODEL_FILENAME)
 
42
 
 
43
 
44
  class FeatureLoss(nn.Module):
45
  def __init__(self, m_feat, layer_ids, layer_wgts):