savitha7 commited on
Commit
8096272
1 Parent(s): 11a3d6f

update app

Browse files
Files changed (1) hide show
  1. app.py +59 -7
app.py CHANGED
@@ -9,27 +9,79 @@ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
9
  # Define the BMI classes
10
  bmi_classes = ["underweight", "normal weight", "overweight", "obesity"]
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  def predict_bmi(image):
13
  # Prepare the inputs
14
  inputs = processor(text=bmi_classes, images=image, return_tensors="pt", padding=True)
15
  # Get model outputs
16
  outputs = model(**inputs)
17
- logits_per_image = outputs.logits_per_image # Image-text similarity scores
18
- probs = logits_per_image.softmax(dim=1) # Convert to probabilities
19
 
20
  # Find the class with the highest probability
21
  max_prob_index = probs.argmax().item()
22
  predicted_bmi_class = bmi_classes[max_prob_index]
23
 
24
- return predicted_bmi_class
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- # Create Gradio interface
27
  interface = gr.Interface(
28
  fn=predict_bmi,
29
- inputs=gr.inputs.Image(type="pil"),
30
- outputs="text",
31
  title="BMI Prediction",
32
- description="Upload an image to predict BMI category (underweight, normal weight, overweight, obesity)."
33
  )
34
 
35
  # Launch the interface
 
9
  # Define the BMI classes
10
  bmi_classes = ["underweight", "normal weight", "overweight", "obesity"]
11
 
12
+ # Define the secondary models for BMI predictions and midpoints
13
+ bmi_ranges = {
14
+ "underweight": {
15
+ "BMI < 16.0": (0, 16.0, 8.0),
16
+ "16.0 ≤ BMI ≤ 16.99": (16.0, 16.99, 16.5),
17
+ "17.0 ≤ BMI ≤ 18.49": (17.0, 18.49, 17.75)
18
+ },
19
+ "normal weight": {
20
+ "18.5 ≤ BMI ≤ 20.4": (18.5, 20.4, 19.45),
21
+ "20.5 ≤ BMI ≤ 22.4": (20.5, 22.4, 21.45),
22
+ "22.5 ≤ BMI ≤ 24.9": (22.5, 24.9, 23.7)
23
+ },
24
+ "overweight": {
25
+ "25.0 ≤ BMI ≤ 26.9": (25.0, 26.9, 25.95),
26
+ "27.0 ≤ BMI ≤ 28.9": (27.0, 28.9, 27.95),
27
+ "29.0 ≤ BMI ≤ 29.9": (29.0, 29.9, 29.45)
28
+ },
29
+ "obesity": {
30
+ "30.0 ≤ BMI ≤ 34.9": (30.0, 34.9, 32.5),
31
+ "35.0 ≤ BMI ≤ 39.9": (35.0, 39.9, 37.45),
32
+ "BMI ≥ 40.0": (40.0, 100, 40.0) # Assuming 100 as the upper limit for BMI
33
+ }
34
+ }
35
+
36
  def predict_bmi(image):
37
  # Prepare the inputs
38
  inputs = processor(text=bmi_classes, images=image, return_tensors="pt", padding=True)
39
  # Get model outputs
40
  outputs = model(**inputs)
41
+ logits_per_image = outputs.logits_per_image # Image-text similarity scores
42
+ probs = logits_per_image.softmax(dim=1) # Convert to probabilities
43
 
44
  # Find the class with the highest probability
45
  max_prob_index = probs.argmax().item()
46
  predicted_bmi_class = bmi_classes[max_prob_index]
47
 
48
+ # Use the midpoint BMI for the predicted class
49
+ bmi_prediction = get_midpoint_bmi(predicted_bmi_class)
50
+
51
+ # Assume height is input by the user or extracted somehow
52
+ height_in_inches = 75 # Example height; replace with actual input or extraction
53
+ predicted_weight = calculate_weight(bmi_prediction, height_in_inches)
54
+
55
+ # Create the JSON output
56
+ result = {
57
+ "weightCategory": f"{predicted_bmi_class} - {bmi_prediction}",
58
+ "bmiPrediction": f"{bmi_prediction:.2f}",
59
+ "height": str(height_in_inches),
60
+ "predictedWeight": f"{predicted_weight:.2f} lbs"
61
+ }
62
+
63
+ return result
64
+
65
+ def get_midpoint_bmi(weight_category):
66
+ """Return the midpoint BMI for the given weight category."""
67
+ category_ranges = bmi_ranges.get(weight_category.lower())
68
+ for range_label, (low, high, mid) in category_ranges.items():
69
+ return mid # Return the first midpoint found in the given range
70
+
71
+ def calculate_weight(bmi, height_in_inches):
72
+ """Calculate the weight from BMI and height (in inches)."""
73
+ height_in_meters = height_in_inches * 0.0254 # Convert height to meters
74
+ weight_kg = bmi * (height_in_meters ** 2) # BMI formula to calculate weight
75
+ weight_lbs = weight_kg * 2.20462 # Convert kg to lbs
76
+ return weight_lbs
77
 
78
+ # Create Gradio interface with updated components
79
  interface = gr.Interface(
80
  fn=predict_bmi,
81
+ inputs=gr.Image(type="pil"),
82
+ outputs="json",
83
  title="BMI Prediction",
84
+ description="Upload an image to predict BMI category (underweight, normal weight, overweight, obesity) and receive a detailed prediction."
85
  )
86
 
87
  # Launch the interface