Shakir60 commited on
Commit
7030681
·
verified ·
1 Parent(s): 84b3e0e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -4
app.py CHANGED
@@ -34,13 +34,20 @@ def load_model():
34
  """Load and cache the model and processor"""
35
  try:
36
  model_name = "google/vit-base-patch16-224"
 
 
 
 
 
37
  model = ViTForImageClassification.from_pretrained(
38
  model_name,
39
  num_labels=len(DAMAGE_TYPES),
40
  ignore_mismatched_sizes=True,
41
- device_map="auto"
42
- )
43
- processor = ViTImageProcessor.from_pretrained(model_name)
 
 
44
  return model, processor
45
  except Exception as e:
46
  st.error(f"Error loading model: {str(e)}")
@@ -77,13 +84,16 @@ def preprocess_image(uploaded_file):
77
  def analyze_damage(image, model, processor):
78
  """Analyze structural damage in the image"""
79
  try:
 
80
  with torch.no_grad():
81
  image = image.convert('RGB')
82
  inputs = processor(images=image, return_tensors="pt")
 
 
83
  outputs = model(**inputs)
84
  probs = torch.nn.functional.softmax(outputs.logits, dim=1)[0]
85
  cleanup_memory()
86
- return probs
87
  except RuntimeError as e:
88
  if "out of memory" in str(e):
89
  cleanup_memory()
 
34
  """Load and cache the model and processor"""
35
  try:
36
  model_name = "google/vit-base-patch16-224"
37
+ # Initialize the processor first
38
+ processor = ViTImageProcessor.from_pretrained(model_name)
39
+
40
+ # Load model with specific device configuration
41
+ device = "cuda" if torch.cuda.is_available() else "cpu"
42
  model = ViTForImageClassification.from_pretrained(
43
  model_name,
44
  num_labels=len(DAMAGE_TYPES),
45
  ignore_mismatched_sizes=True,
46
+ ).to(device)
47
+
48
+ # Ensure model is in evaluation mode
49
+ model.eval()
50
+
51
  return model, processor
52
  except Exception as e:
53
  st.error(f"Error loading model: {str(e)}")
 
84
  def analyze_damage(image, model, processor):
85
  """Analyze structural damage in the image"""
86
  try:
87
+ device = next(model.parameters()).device
88
  with torch.no_grad():
89
  image = image.convert('RGB')
90
  inputs = processor(images=image, return_tensors="pt")
91
+ # Move inputs to the same device as model
92
+ inputs = {k: v.to(device) for k, v in inputs.items()}
93
  outputs = model(**inputs)
94
  probs = torch.nn.functional.softmax(outputs.logits, dim=1)[0]
95
  cleanup_memory()
96
+ return probs.cpu() # Move results back to CPU
97
  except RuntimeError as e:
98
  if "out of memory" in str(e):
99
  cleanup_memory()