Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import gradio as gr
|
3 |
+
from fastai.vision.all import *
|
4 |
+
from groq import Groq
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
# Load the trained model
|
8 |
+
learn = load_learner('export.pkl')
|
9 |
+
labels = learn.dls.vocab
|
10 |
+
|
11 |
+
# Initialize Groq client
|
12 |
+
client = Groq(
|
13 |
+
api_key=os.environ.get("GROQ_API_KEY"),
|
14 |
+
)
|
15 |
+
|
16 |
+
def get_bird_info(bird_name):
|
17 |
+
"""Get detailed information about a bird using Groq API"""
|
18 |
+
prompt = f"""
|
19 |
+
Provide detailed information about the {bird_name} bird, including:
|
20 |
+
1. Physical characteristics and appearance
|
21 |
+
2. Habitat and distribution
|
22 |
+
3. Diet and behavior
|
23 |
+
4. Migration patterns (emphasize if this pattern has changed in recent years due to climate change)
|
24 |
+
5. Conservation status
|
25 |
+
|
26 |
+
If this bird is not commonly found in Tanzania, explicitly flag that this bird is "NOT TYPICALLY FOUND IN TANZANIA" at the beginning of your response and explain why its presence might be unusual.
|
27 |
+
|
28 |
+
Format your response in markdown for better readability.
|
29 |
+
"""
|
30 |
+
|
31 |
+
try:
|
32 |
+
chat_completion = client.chat.completions.create(
|
33 |
+
messages=[
|
34 |
+
{
|
35 |
+
"role": "user",
|
36 |
+
"content": prompt,
|
37 |
+
}
|
38 |
+
],
|
39 |
+
model="llama-3.3-70b-versatile",
|
40 |
+
)
|
41 |
+
return chat_completion.choices[0].message.content
|
42 |
+
except Exception as e:
|
43 |
+
return f"Error fetching information: {str(e)}"
|
44 |
+
|
45 |
+
def predict_and_get_info(img):
|
46 |
+
"""Predict bird species and get detailed information"""
|
47 |
+
# Process the image
|
48 |
+
img = PILImage.create(img)
|
49 |
+
|
50 |
+
# Get prediction
|
51 |
+
pred, pred_idx, probs = learn.predict(img)
|
52 |
+
|
53 |
+
# Format prediction results
|
54 |
+
prediction_results = {labels[i]: float(probs[i]) for i in range(len(labels))}
|
55 |
+
|
56 |
+
# Get top prediction
|
57 |
+
top_bird = str(pred)
|
58 |
+
|
59 |
+
# Get detailed information about the top predicted bird
|
60 |
+
bird_info = get_bird_info(top_bird)
|
61 |
+
|
62 |
+
return prediction_results, bird_info
|
63 |
+
|
64 |
+
def follow_up_question(question, bird_name):
|
65 |
+
"""Allow researchers to ask follow-up questions about the identified bird"""
|
66 |
+
prompt = f"""
|
67 |
+
The researcher is asking about the {bird_name} bird: "{question}"
|
68 |
+
|
69 |
+
Provide a detailed, scientific answer focusing on accurate ornithological information.
|
70 |
+
If the question relates to Tanzania or climate change impacts, emphasize those aspects in your response.
|
71 |
+
Format your response in markdown for better readability.
|
72 |
+
"""
|
73 |
+
|
74 |
+
try:
|
75 |
+
chat_completion = client.chat.completions.create(
|
76 |
+
messages=[
|
77 |
+
{
|
78 |
+
"role": "user",
|
79 |
+
"content": prompt,
|
80 |
+
}
|
81 |
+
],
|
82 |
+
model="llama-3.3-70b-versatile",
|
83 |
+
)
|
84 |
+
return chat_completion.choices[0].message.content
|
85 |
+
except Exception as e:
|
86 |
+
return f"Error fetching information: {str(e)}"
|
87 |
+
|
88 |
+
# Create the Gradio interface
|
89 |
+
with gr.Blocks(theme=gr.themes.Soft()) as app:
|
90 |
+
gr.Markdown("# Bird Species Identification for Researchers")
|
91 |
+
gr.Markdown("Upload an image to identify bird species and get detailed information relevant to research in Tanzania and climate change studies.")
|
92 |
+
|
93 |
+
with gr.Row():
|
94 |
+
with gr.Column(scale=1):
|
95 |
+
input_image = gr.Image(type="pil", label="Upload Bird Image")
|
96 |
+
submit_btn = gr.Button("Identify Bird", variant="primary")
|
97 |
+
|
98 |
+
with gr.Column(scale=2):
|
99 |
+
with gr.Row():
|
100 |
+
prediction_output = gr.Label(label="Prediction Results")
|
101 |
+
with gr.Row():
|
102 |
+
bird_info_output = gr.Markdown(label="Bird Information")
|
103 |
+
|
104 |
+
# For follow-up questions
|
105 |
+
with gr.Row():
|
106 |
+
gr.Markdown("## Ask Follow-up Questions")
|
107 |
+
|
108 |
+
with gr.Row():
|
109 |
+
with gr.Column(scale=1):
|
110 |
+
current_bird = gr.State("")
|
111 |
+
follow_up_input = gr.Textbox(label="Your Question", placeholder="Ask more about this bird species...")
|
112 |
+
follow_up_btn = gr.Button("Submit Question")
|
113 |
+
|
114 |
+
with gr.Column(scale=2):
|
115 |
+
follow_up_output = gr.Markdown(label="Answer")
|
116 |
+
|
117 |
+
# Set up event handlers
|
118 |
+
def process_image(img):
|
119 |
+
pred_results, info = predict_and_get_info(img)
|
120 |
+
# Extract bird name from top prediction
|
121 |
+
bird_name = max(pred_results.items(), key=lambda x: x[1])[0]
|
122 |
+
return pred_results, info, bird_name
|
123 |
+
|
124 |
+
submit_btn.click(
|
125 |
+
process_image,
|
126 |
+
inputs=[input_image],
|
127 |
+
outputs=[prediction_output, bird_info_output, current_bird]
|
128 |
+
)
|
129 |
+
|
130 |
+
follow_up_btn.click(
|
131 |
+
follow_up_question,
|
132 |
+
inputs=[follow_up_input, current_bird],
|
133 |
+
outputs=[follow_up_output]
|
134 |
+
)
|
135 |
+
|
136 |
+
# Launch the app
|
137 |
+
app.launch(share=True)
|