devdata committed on
Commit c5b6e4e · 1 Parent(s): 371dc1b

Update app.py

Files changed (1)
  1. app.py +59 -0
app.py CHANGED
@@ -0,0 +1,59 @@
+ import os
+
+ import gradio as gr
+ from fastai.vision.all import *
+ import skimage
+ import openai
+
+ openai.api_key = os.getenv("OPENAI_API_KEY")
+
+ # Load the exported fastai learner
+ learn = load_learner('model.pkl')
+
+ # The class labels come from the DataLoaders vocabulary
+ labels = learn.dls.vocab
+
+ # Generate free-form commentary with the OpenAI completion endpoint
+ def generate_text(prompt):
+     response = openai.Completion.create(
+         engine="davinci",
+         prompt=prompt,
+         max_tokens=1024,
+         n=1,
+         stop=None,
+         temperature=0.7,
+     )
+     return response.choices[0].text.strip()
+
+ # Answer a user query with the chat completion endpoint, including any prior turns
+ def handle_query(query, chat_history):
+     response = openai.ChatCompletion.create(
+         model="gpt-3.5-turbo",
+         messages=[{"role": "system", "content": "You are a helpful assistant."},
+                   {"role": "user", "content": query}] + chat_history
+     )
+     return response.choices[0].message['content']
+
+ # Classify an image and ask the language model to comment on the prediction
+ def predict(img):
+     img = PILImage.create(img)
+     pred, pred_idx, probs = learn.predict(img)
+     prediction = {labels[i]: float(probs[i]) for i in range(len(labels))}
+     chat_prompt = f"The model predicted {prediction}."
+     chat_response = generate_text(chat_prompt)
+     return prediction, chat_response
+
+ # Relay a single chat turn to the OpenAI backend
+ def chat(query, chat_history):
+     chat_response = handle_query(query, chat_history)
+     return chat_response
+
+ # Example images shown in the interface
+ examples = ['image.jpg']
+
+ # Input interpretation mode for the image input
+ interpretation = 'default'
+
+ # Queue requests so slow predictions do not block the UI
+ enable_queue = True
+
+ # Launch the interface: image in, label probabilities and generated commentary out
+ gr.Interface(fn=predict,
+              inputs=gr.inputs.Image(shape=(512, 512)),
+              outputs=[gr.outputs.Label(num_top_classes=3),
+                       gr.outputs.Textbox(label="Model commentary")],
+              examples=examples,
+              interpretation=interpretation,
+              enable_queue=enable_queue).launch()