Files changed (1)
  1. app.py +68 -0
app.py CHANGED
@@ -0,0 +1,68 @@
+ import os
+ import gradio as gr
+ from fastai.vision.all import *
+ import skimage
+ import openai
+
+ # Read the OpenAI API key from the environment
+ openai.api_key = os.getenv("OPENAI_API_KEY")
+
+ # Load the exported fastai model
+ learn = load_learner('model.pkl')
+
+ # The class labels come from the DataLoaders vocabulary
+ labels = learn.dls.vocab
+
+ # Generate text from a prompt with the OpenAI completion API
+ def generate_text(prompt):
+     response = openai.Completion.create(
+         engine="davinci",
+         prompt=prompt,
+         max_tokens=1024,
+         n=1,
+         stop=None,
+         temperature=0.7,
+     )
+     return response.choices[0].text.strip()
+
+ # Answer a user query with the chat API, placing prior history before the new query
+ def handle_query(query, chat_history):
+     response = openai.ChatCompletion.create(
+         model="gpt-3.5-turbo",
+         messages=[{"role": "system", "content": "You are a helpful assistant."}]
+                  + chat_history + [{"role": "user", "content": query}]
+     )
+     return response.choices[0].message['content']
+
+ # Classify an image and ask the language model to comment on the prediction
+ def predict(img):
+     img = PILImage.create(img)
+     pred, pred_idx, probs = learn.predict(img)
+     prediction = {labels[i]: float(probs[i]) for i in range(len(labels))}
+     chat_prompt = f"The model predicted {prediction}."
+     chat_response = generate_text(chat_prompt)
+     return prediction, chat_response
+
+ # Chat function that wraps handle_query
+ def chat(query, chat_history):
+     chat_response = handle_query(query, chat_history)
+     return chat_response
+
+ # Example images shown in the interface
+ examples = ['image.jpg']
+
+ # Interpretation method for the image input
+ interpretation = 'default'
+
+ # Queue incoming requests
+ enable_queue = True
+
+ # Build and launch the interface
+ gr.Interface(
+     fn=predict,
+     inputs=gr.inputs.Image(shape=(512, 512)),
+     outputs=[gr.outputs.Label(num_top_classes=3), gr.outputs.Textbox(label="Model commentary")],
+     examples=examples,
+     interpretation=interpretation,
+     enable_queue=enable_queue,
+ ).launch()
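
Note: `chat` and `handle_query` are defined in this change but not wired into the interface that is launched; they expect `chat_history` to be a list of OpenAI-style message dicts. A minimal usage sketch, with hypothetical history values that are not part of this change:

    # Hypothetical direct use of the chat helper (assumed message format: role/content dicts)
    history = [
        {"role": "user", "content": "What does the classifier do?"},
        {"role": "assistant", "content": "It labels images using a fastai model."},
    ]
    reply = chat("Which label was most likely for my last image?", history)
    print(reply)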