# Professor-Bean by nprasad24
# Revision faf18c9 (verified), 4.7 kB
import subprocess
import sys

# Install the app's dependencies at startup, before the third-party imports below.
# Running pip through the current interpreter (`sys.executable -m pip`) makes the
# packages land in the environment this process actually uses. Passing an argument
# list (no `shell=True`) avoids shell parsing. `check=True` surfaces a failed
# install here, instead of as a confusing ImportError further down.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "-r", "requirements.txt"],
    check=True,
)
import functools
import os

import gradio as gr
import numpy as np
import transformers
from PIL import Image
from transformers import pipeline
from transformers import AutoModelForImageClassification, AutoImageProcessor
from langchain import HuggingFacePipeline
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_fireworks import ChatFireworks
# Hugging Face repo of the fine-tuned bean-leaf classifier.
_CLASSIFIER_REPO = "nprasad24/bean_classifier"

# Inclusive score band that counts as "uncertain" for one class. When every
# class score falls inside it, the classifier has no clear winner, and the image
# is treated as not being a plant.
_UNCERTAIN_SCORE_RANGE = (0.2, 0.4)

# Query sent to the LLM for each label the classifier can predict.
_LABEL_QUERIES = {
    "healthy": "The plant is healthy. Give tips on maintaining the plant",
    "bean_rust": "The detected disease is bean rust. Explain the disease",
    "angular_leaf_spot": "The detected disease is angular leaf spot. Explain the disease",
}

# Query used when the image does not look like a (bean) plant.
_NOT_A_PLANT_QUERY = "Given image is not of a plant."


@functools.lru_cache(maxsize=1)
def _load_classifier():
    """Build the image-classification pipeline once per process and reuse it."""
    model = AutoModelForImageClassification.from_pretrained(_CLASSIFIER_REPO, from_tf=True)
    image_processor = AutoImageProcessor.from_pretrained(_CLASSIFIER_REPO)
    return pipeline("image-classification", model=model, image_processor=image_processor)


def image_to_query(image):
    """Classify a leaf image with the fine-tuned model and turn it into an LLM query.

    input: image as an array accepted by ``PIL.Image.fromarray``
           (presumably the NumPy array from the Gradio image input — confirm).
    function: performs image classification using the fine-tuned model.
    output: query string for the LLM. An unrecognised label, or a prediction with
            no confident class, yields "Given image is not of a plant."
    """
    classifier = _load_classifier()  # cached: loading the model is expensive
    scores = classifier(Image.fromarray(image))

    best_label = max(scores, key=lambda entry: entry["score"])["label"]

    # Crude "is this really a leaf?" check: if every class score sits in the
    # uncertainty band, the model is effectively guessing.
    low, high = _UNCERTAIN_SCORE_RANGE
    uncertain = sum(1 for entry in scores if low <= entry["score"] <= high)
    if uncertain == len(_LABEL_QUERIES):
        return _NOT_A_PLANT_QUERY

    return _LABEL_QUERIES.get(best_label, _NOT_A_PLANT_QUERY)
# System instructions for the agricultural assistant LLM.
_SYSTEM_PROMPT = (
    "You are a knowledgeable agricultural assistant. If a disease is detected, "
    "you have to give information on the disease.\n"
    "If the image is not of a plant, ask human to upload image of a plant and "
    "stop generating any response.\n"
    "If the plant is healthy, just give maintenance tips. "
    "Provide information about the leaf disease in question in bullet points.\n"
    "Start your answer by mentioning the disease (if any) or healthy in this "
    "format: 'Condition: disease name'.\n"
)


def _format_docs(docs):
    """Join retrieved documents into one plain-text context string for the prompt."""
    return "\n\n".join(doc.page_content for doc in docs)


@functools.lru_cache(maxsize=1)
def ragChain():
    """Create the retrieval-augmented generation (RAG) chain.

    function: retrieves the 5 most similar knowledge-base chunks from the local
              FAISS index and asks the Fireworks-hosted Mixtral model to answer
              using them.
    output: a runnable; ``invoke(query_str)`` returns the LLM answer as a string.

    The chain is built once per process and then reused, so the embeddings and
    index are not reloaded on every request. The ``APIKEY`` environment variable
    is read on that first call only.
    """
    # The FAISS index is loaded pre-built from disk. Presumably it was built
    # offline from knowledgeBase.txt; the original code re-read and re-split that
    # file here, but never used the result.
    # NOTE(review): allow_dangerous_deserialization lets load_local unpickle the
    # stored docstore. That is only acceptable because "faiss_index" ships with
    # the app; never point this at untrusted files.
    vectorstore = FAISS.load_local(
        "faiss_index",
        embeddings=HuggingFaceEmbeddings(),
        allow_dangerous_deserialization=True,
    )
    retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 5})

    llm = ChatFireworks(
        model="accounts/fireworks/models/mixtral-8x7b-instruct",
        api_key=os.getenv("APIKEY"),
    )

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", _SYSTEM_PROMPT),
            ("human", "{context}, {question}"),
        ]
    )

    # The input query fans out: it feeds the retriever (whose documents are
    # flattened to text) and is also passed through unchanged as the question.
    return (
        {"context": retriever | _format_docs, "question": RunnablePassthrough()}
        | prompt
        | llm
        | StrOutputParser()
    )
def generate_response(rag_chain, query):
    """Run the RAG chain on a query and return the LLM's answer.

    input: rag_chain, any runnable exposing ``invoke``;
           query, the question (converted to ``str`` before invoking).
    function: generates a response using the LLM and the knowledge base.
    output: the value returned by ``rag_chain.invoke``.
    """
    return rag_chain.invoke(str(query))
def main(image):
    """Gradio callback: classify the uploaded leaf image and return the LLM's advice.

    input: image from the Gradio image input. Presumably this is ``None`` when the
           user submits without uploading anything — confirm against the
           Interface's ``inputs``.
    output: the generated advice text, or a prompt to upload an image.
    """
    # Without this guard, Image.fromarray(None) would raise inside the classifier
    # and the user would only see a generic error.
    if image is None:
        return "Please upload an image of a bean leaf."

    query = image_to_query(image)
    return generate_response(ragChain(), query)
# Header text displayed at the top of the Gradio interface.
title = "Professor Bean: The Bean Disease Expert"
description = (
    "Professor Bean is an agricultural expert. "
    "He will guide you on how to protect your plants from bean diseases"
)
app = gr.Interface(fn=main,
examples=[["sampleImages/sample1.jpg"], ["sampleImages/sample2.jpg"],["sampleImages/sample3.jpg"], ["sampleImages/sample4.jpeg"]]