Spaces:
Sleeping
Sleeping
Sontranwakumo
commited on
Commit
·
b5f3034
1
Parent(s):
5dfb339
fix: predict bugs
Browse files- .gitignore +1 -0
- app/services/predict.py +11 -7
.gitignore
CHANGED
@@ -110,3 +110,4 @@ venv.bak/
|
|
110 |
|
111 |
# FastAPI specific
|
112 |
.pytest_cache/
|
|
|
|
110 |
|
111 |
# FastAPI specific
|
112 |
.pytest_cache/
|
113 |
+
.DS_Store
|
app/services/predict.py
CHANGED
@@ -40,6 +40,9 @@ class PredictService:
|
|
40 |
if request.context.nodes is None:
|
41 |
request.context.nodes = []
|
42 |
request.context.nodes = request.context.nodes + additional_nodes
|
|
|
|
|
|
|
43 |
|
44 |
env_task = asyncio.create_task(
|
45 |
kg.get_disease_from_env_factors(request.context.crop_id, request.context.nodes)
|
@@ -48,10 +51,11 @@ class PredictService:
|
|
48 |
kg.get_disease_from_symptoms(request.context.crop_id, request.context.nodes)
|
49 |
)
|
50 |
|
51 |
-
|
52 |
context = request.context
|
53 |
-
context.nodes.extend([env_result["disease"] for env_result in
|
54 |
-
context.nodes.extend([symptom_result["disease"] for symptom_result in
|
|
|
55 |
context.nodes.sort(key=lambda x: x.score, reverse=True)
|
56 |
|
57 |
# Tính toán final_labels bằng trung bình có trọng số
|
@@ -59,15 +63,15 @@ class PredictService:
|
|
59 |
print("Got predicted labels")
|
60 |
context.final_labels = self.calculate_final_labels(
|
61 |
context.predicted_labels,
|
62 |
-
|
63 |
-
|
64 |
context.crop_id
|
65 |
)
|
66 |
|
67 |
return {
|
68 |
"context": context,
|
69 |
-
"
|
70 |
-
"
|
71 |
}
|
72 |
|
73 |
except Exception as e:
|
|
|
40 |
if request.context.nodes is None:
|
41 |
request.context.nodes = []
|
42 |
request.context.nodes = request.context.nodes + additional_nodes
|
43 |
+
for node in request.context.nodes:
|
44 |
+
if node.score is None:
|
45 |
+
node.score = 0.9
|
46 |
|
47 |
env_task = asyncio.create_task(
|
48 |
kg.get_disease_from_env_factors(request.context.crop_id, request.context.nodes)
|
|
|
51 |
kg.get_disease_from_symptoms(request.context.crop_id, request.context.nodes)
|
52 |
)
|
53 |
|
54 |
+
env_results, symptom_results = await asyncio.gather(env_task, symptom_task)
|
55 |
context = request.context
|
56 |
+
context.nodes.extend([env_result["disease"] for env_result in env_results])
|
57 |
+
context.nodes.extend([symptom_result["disease"] for symptom_result in symptom_results])
|
58 |
+
print(context.nodes)
|
59 |
context.nodes.sort(key=lambda x: x.score, reverse=True)
|
60 |
|
61 |
# Tính toán final_labels bằng trung bình có trọng số
|
|
|
63 |
print("Got predicted labels")
|
64 |
context.final_labels = self.calculate_final_labels(
|
65 |
context.predicted_labels,
|
66 |
+
env_results,
|
67 |
+
symptom_results,
|
68 |
context.crop_id
|
69 |
)
|
70 |
|
71 |
return {
|
72 |
"context": context,
|
73 |
+
"env_results": env_results,
|
74 |
+
"symptom_results": symptom_results
|
75 |
}
|
76 |
|
77 |
except Exception as e:
|