Sontranwakumo commited on
Commit
b5f3034
·
1 Parent(s): 5dfb339

fix: predict bugs

Browse files
Files changed (2) hide show
  1. .gitignore +1 -0
  2. app/services/predict.py +11 -7
.gitignore CHANGED
@@ -110,3 +110,4 @@ venv.bak/
110
 
111
  # FastAPI specific
112
  .pytest_cache/
 
 
110
 
111
  # FastAPI specific
112
  .pytest_cache/
113
+ .DS_Store
app/services/predict.py CHANGED
@@ -40,6 +40,9 @@ class PredictService:
40
  if request.context.nodes is None:
41
  request.context.nodes = []
42
  request.context.nodes = request.context.nodes + additional_nodes
 
 
 
43
 
44
  env_task = asyncio.create_task(
45
  kg.get_disease_from_env_factors(request.context.crop_id, request.context.nodes)
@@ -48,10 +51,11 @@ class PredictService:
48
  kg.get_disease_from_symptoms(request.context.crop_id, request.context.nodes)
49
  )
50
 
51
- env_result, symptom_result = await asyncio.gather(env_task, symptom_task)
52
  context = request.context
53
- context.nodes.extend([env_result["disease"] for env_result in env_result])
54
- context.nodes.extend([symptom_result["disease"] for symptom_result in symptom_result])
 
55
  context.nodes.sort(key=lambda x: x.score, reverse=True)
56
 
57
  # Tính toán final_labels bằng trung bình có trọng số
@@ -59,15 +63,15 @@ class PredictService:
59
  print("Got predicted labels")
60
  context.final_labels = self.calculate_final_labels(
61
  context.predicted_labels,
62
- env_result,
63
- symptom_result,
64
  context.crop_id
65
  )
66
 
67
  return {
68
  "context": context,
69
- "env_result": env_result,
70
- "symptom_result": symptom_result
71
  }
72
 
73
  except Exception as e:
 
40
  if request.context.nodes is None:
41
  request.context.nodes = []
42
  request.context.nodes = request.context.nodes + additional_nodes
43
+ for node in request.context.nodes:
44
+ if node.score is None:
45
+ node.score = 0.9
46
 
47
  env_task = asyncio.create_task(
48
  kg.get_disease_from_env_factors(request.context.crop_id, request.context.nodes)
 
51
  kg.get_disease_from_symptoms(request.context.crop_id, request.context.nodes)
52
  )
53
 
54
+ env_results, symptom_results = await asyncio.gather(env_task, symptom_task)
55
  context = request.context
56
+ context.nodes.extend([env_result["disease"] for env_result in env_results])
57
+ context.nodes.extend([symptom_result["disease"] for symptom_result in symptom_results])
58
+ print(context.nodes)
59
  context.nodes.sort(key=lambda x: x.score, reverse=True)
60
 
61
  # Tính toán final_labels bằng trung bình có trọng số
 
63
  print("Got predicted labels")
64
  context.final_labels = self.calculate_final_labels(
65
  context.predicted_labels,
66
+ env_results,
67
+ symptom_results,
68
  context.crop_id
69
  )
70
 
71
  return {
72
  "context": context,
73
+ "env_results": env_results,
74
+ "symptom_results": symptom_results
75
  }
76
 
77
  except Exception as e: