# chat_vision / app.py
import gradio as gr
from fastai.vision.all import *
import openai
import os
openai.api_key = os.getenv("OPENAI_API_KEY")
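# Optional startup guard: fail fast if the key is missing rather than
# erroring on the first request.
if openai.api_key is None:
    raise RuntimeError("OPENAI_API_KEY environment variable is not set")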
# Load the trained fastai learner (replace 'model.pkl' with the path to your exported model file)
learn = load_learner('model.pkl')
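# The .pkl file is an exported fastai Learner, produced after training with e.g.:
#   learn.export('model.pkl')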
# Define the labels for the output
labels = learn.dls.vocab
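# labels is the ordered list of class names from the DataLoaders,
# e.g. ['cat', 'dog'] for a hypothetical two-class classifier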
# Define the prediction function
def predict(img):
    img = PILImage.create(img)
    pred, pred_idx, probs = learn.predict(img)
    # Map every class label to its predicted probability
    prediction = {labels[i]: float(probs[i]) for i in range(len(labels))}
    # Now generate a chat/text response based on the model's prediction
    chat_prompt = f"The image likely depicts the following: {pred}. What can I help you with next?"
    # This call requires the OPENAI_API_KEY environment variable to be set.
    # It uses the legacy Completion API from openai<1.0; note that the
    # text-davinci-003 engine has since been deprecated by OpenAI.
    response = openai.Completion.create(
        engine="text-davinci-003",  # adjust the engine as needed for your use-case
        prompt=chat_prompt,
        max_tokens=1024,
        n=1,
        stop=None,
        temperature=0.7,
    )
    text_response = response.choices[0].text.strip()
    return prediction, text_response
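# If the Space is ever upgraded to openai>=1.0, the legacy Completion endpoint
# is no longer available; a rough equivalent using the Chat Completions API
# (the model name here is an assumption) would be:
#   from openai import OpenAI
#   client = OpenAI()  # reads OPENAI_API_KEY from the environment
#   resp = client.chat.completions.create(
#       model="gpt-3.5-turbo",
#       messages=[{"role": "user", "content": chat_prompt}],
#       max_tokens=1024,
#       temperature=0.7,
#   )
#   text_response = resp.choices[0].message.content.strip()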
# Create examples list by specifying the paths to the example images
examples = ["path/to/example1.jpg", "path/to/example2.jpg"] # replace with actual image paths
# Define the Gradio interface
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(shape=(512, 512)),  # 'shape' is a Gradio 3.x argument; Gradio 4+ uses height/width
    outputs=[gr.Label(num_top_classes=3), gr.Textbox(label="GPT-3 Response")],
    examples=examples,
)
# Queue incoming requests (optional; only needed under heavy traffic).
# Recent Gradio versions enable queueing via .queue() rather than an enable_queue flag.
iface.queue()
# Launch the Gradio app
iface.launch()
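# When running outside Hugging Face Spaces, a temporary public URL can be
# requested instead, e.g.:
#   iface.launch(share=True)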