File size: 4,701 Bytes
bfe7e73
faf18c9
bfe7e73
 
 
 
 
 
 
 
 
 
 
 
 
8eb9295
bfe7e73
8011378
b8c40a7
bfe7e73
 
5cfc086
bfe7e73
 
 
 
 
 
 
ed69b0e
bfe7e73
5cfc086
 
 
 
 
 
bfe7e73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4bc3117
 
bfe7e73
4bc3117
 
bfe7e73
bf6f587
bfe7e73
 
19c20c1
 
bfe7e73
 
 
 
 
5cfc086
 
bfe7e73
 
 
 
6cffbb4
bfe7e73
 
 
6cffbb4
bfe7e73
 
 
 
 
6cffbb4
bfe7e73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc00138
bfe7e73
c4c008b
 
bfe7e73
 
 
 
41eed49
bfe7e73
0caec8b
 
 
 
 
7fda5c2
0caec8b
bc00138
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import subprocess
#subprocess.run('pip install -r requirements.txt', shell = True)

import gradio as gr
import os
from PIL import Image
import numpy as np
from transformers import pipeline
import transformers
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_fireworks import ChatFireworks
from langchain_core.prompts import ChatPromptTemplate
from transformers import AutoModelForImageClassification, AutoImageProcessor
from langchain import HuggingFacePipeline


def image_to_query(image):
    """
    input: Image

    function: Performs image classification using fine-tuned model

    output: Query for the LLM
    """
    image =  Image.fromarray(image)

    model = AutoModelForImageClassification.from_pretrained("nprasad24/bean_classifier", from_tf=True)
    image_processor = AutoImageProcessor.from_pretrained("nprasad24/bean_classifier")

    classifier = pipeline("image-classification", model=model, image_processor=image_processor)

    scores = classifier(image)

    # Get the dictionary with the maximum score
    max_score_dict = max(scores, key=lambda x: x['score'])

    # Extract the label with the maximum score
    label_with_max_score = max_score_dict['label']

    # script to check if the image uploaded is indeed a leaf or not
    counter = 0
    for ele in scores:
        if 0.2 <= ele['score'] <= 0.4:
            counter += 1

    if label_with_max_score == 'healthy' and counter != 3:
        query = "The plant is healthy. Give tips on maintaining the plant"
    elif label_with_max_score == 'bean_rust' and counter != 3:
        query = "The detected disease is bean rust. Explain the disease"
    elif label_with_max_score == 'angular_leaf_spot' and counter != 3:
        query = "The detected disease is angular leaf spot. Explain the disease"
    else:
        query = "Given image is not of a plant."

    return query

def ragChain():
    """
    function: creates a rag chain

    output: rag chain
    """
    loader = TextLoader("knowledgeBase.txt")
    docs = loader.load()

    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    docs = text_splitter.split_documents(docs)

    vectorstore = vectorstore = FAISS.load_local("faiss_index", embeddings = HuggingFaceEmbeddings(), allow_dangerous_deserialization = True)
    retriever = vectorstore.as_retriever(search_type = "similarity", search_kwargs = {"k": 5})

    api_key = os.getenv("APIKEY")
    llm = ChatFireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct", api_key = api_key)

    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                """You are a knowledgeable agricultural assistant. If a disease is detected, you have to give information on the disease.
                If the image is not of a plant, ask human to upload image of a plant and stop generating any response.
                If the plant is healthy, just give maintenance tips. """
            ),
            (
                "human",
                """Provide information about the leaf disease in question in bullet points.
                Start your answer by mentioning the disease (if any) or healthy in this format: 'Condition: disease name'.
                """,
            ),

            ("human", "{context}, {question}"),
        ]
    )

    rag_chain = (
    {
    "context": retriever,
    "question": RunnablePassthrough()
    }
    | prompt
    | llm
    | StrOutputParser()
    )

    return rag_chain

def generate_response(rag_chain, query):
    """
    input: rag chain, query

    function: generates response using llm and knowledge base

    output: generated response by the llm
    """
    return rag_chain.invoke(f"{query}")

def main(image):
    query = image_to_query(image)
    chain = ragChain()
    output = generate_response(chain, query)
    return output

title = "Professor Bean: The Bean Disease Expert"
description = "Professor Bean is an agricultural expert. He will guide you on how to protect your plants from bean diseases"
app = gr.Interface(fn=main, 
                   inputs="image", 
                   outputs="text", 
                   title=title,
                   description=description,
                   examples=[["sampleImages/sample1.jpg"], ["sampleImages/sample2.jpg"],["sampleImages/sample3.jpg"], ["sampleImages/sample4.jpeg"]]
        )
app.launch(share=True)