lyimo commited on
Commit
e31feff
·
verified ·
1 Parent(s): e002c60

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +137 -0
app.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ from fastai.vision.all import *
4
+ from groq import Groq
5
+ from PIL import Image
6
+
7
+ # Load the trained model
8
+ learn = load_learner('export.pkl')
9
+ labels = learn.dls.vocab
10
+
11
+ # Initialize Groq client
12
+ client = Groq(
13
+ api_key=os.environ.get("GROQ_API_KEY"),
14
+ )
15
+
16
+ def get_bird_info(bird_name):
17
+ """Get detailed information about a bird using Groq API"""
18
+ prompt = f"""
19
+ Provide detailed information about the {bird_name} bird, including:
20
+ 1. Physical characteristics and appearance
21
+ 2. Habitat and distribution
22
+ 3. Diet and behavior
23
+ 4. Migration patterns (emphasize if this pattern has changed in recent years due to climate change)
24
+ 5. Conservation status
25
+
26
+ If this bird is not commonly found in Tanzania, explicitly flag that this bird is "NOT TYPICALLY FOUND IN TANZANIA" at the beginning of your response and explain why its presence might be unusual.
27
+
28
+ Format your response in markdown for better readability.
29
+ """
30
+
31
+ try:
32
+ chat_completion = client.chat.completions.create(
33
+ messages=[
34
+ {
35
+ "role": "user",
36
+ "content": prompt,
37
+ }
38
+ ],
39
+ model="llama-3.3-70b-versatile",
40
+ )
41
+ return chat_completion.choices[0].message.content
42
+ except Exception as e:
43
+ return f"Error fetching information: {str(e)}"
44
+
45
+ def predict_and_get_info(img):
46
+ """Predict bird species and get detailed information"""
47
+ # Process the image
48
+ img = PILImage.create(img)
49
+
50
+ # Get prediction
51
+ pred, pred_idx, probs = learn.predict(img)
52
+
53
+ # Format prediction results
54
+ prediction_results = {labels[i]: float(probs[i]) for i in range(len(labels))}
55
+
56
+ # Get top prediction
57
+ top_bird = str(pred)
58
+
59
+ # Get detailed information about the top predicted bird
60
+ bird_info = get_bird_info(top_bird)
61
+
62
+ return prediction_results, bird_info
63
+
64
+ def follow_up_question(question, bird_name):
65
+ """Allow researchers to ask follow-up questions about the identified bird"""
66
+ prompt = f"""
67
+ The researcher is asking about the {bird_name} bird: "{question}"
68
+
69
+ Provide a detailed, scientific answer focusing on accurate ornithological information.
70
+ If the question relates to Tanzania or climate change impacts, emphasize those aspects in your response.
71
+ Format your response in markdown for better readability.
72
+ """
73
+
74
+ try:
75
+ chat_completion = client.chat.completions.create(
76
+ messages=[
77
+ {
78
+ "role": "user",
79
+ "content": prompt,
80
+ }
81
+ ],
82
+ model="llama-3.3-70b-versatile",
83
+ )
84
+ return chat_completion.choices[0].message.content
85
+ except Exception as e:
86
+ return f"Error fetching information: {str(e)}"
87
+
88
+ # Create the Gradio interface
89
+ with gr.Blocks(theme=gr.themes.Soft()) as app:
90
+ gr.Markdown("# Bird Species Identification for Researchers")
91
+ gr.Markdown("Upload an image to identify bird species and get detailed information relevant to research in Tanzania and climate change studies.")
92
+
93
+ with gr.Row():
94
+ with gr.Column(scale=1):
95
+ input_image = gr.Image(type="pil", label="Upload Bird Image")
96
+ submit_btn = gr.Button("Identify Bird", variant="primary")
97
+
98
+ with gr.Column(scale=2):
99
+ with gr.Row():
100
+ prediction_output = gr.Label(label="Prediction Results")
101
+ with gr.Row():
102
+ bird_info_output = gr.Markdown(label="Bird Information")
103
+
104
+ # For follow-up questions
105
+ with gr.Row():
106
+ gr.Markdown("## Ask Follow-up Questions")
107
+
108
+ with gr.Row():
109
+ with gr.Column(scale=1):
110
+ current_bird = gr.State("")
111
+ follow_up_input = gr.Textbox(label="Your Question", placeholder="Ask more about this bird species...")
112
+ follow_up_btn = gr.Button("Submit Question")
113
+
114
+ with gr.Column(scale=2):
115
+ follow_up_output = gr.Markdown(label="Answer")
116
+
117
+ # Set up event handlers
118
+ def process_image(img):
119
+ pred_results, info = predict_and_get_info(img)
120
+ # Extract bird name from top prediction
121
+ bird_name = max(pred_results.items(), key=lambda x: x[1])[0]
122
+ return pred_results, info, bird_name
123
+
124
+ submit_btn.click(
125
+ process_image,
126
+ inputs=[input_image],
127
+ outputs=[prediction_output, bird_info_output, current_bird]
128
+ )
129
+
130
+ follow_up_btn.click(
131
+ follow_up_question,
132
+ inputs=[follow_up_input, current_bird],
133
+ outputs=[follow_up_output]
134
+ )
135
+
136
+ # Launch the app
137
+ app.launch(share=True)