nprasad24 commited on
Commit
bfe7e73
·
verified ·
1 Parent(s): fd714c4

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +137 -0
app.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess
2
+ subprocess.run('pip install -r requirements.txt', shell = True)
3
+
4
+ import gradio as gr
5
+ import os
6
+ from PIL import Image
7
+ import numpy as np
8
+ from transformers import pipeline
9
+ import transformers
10
+ transformers.logging.set_verbosity_error()
11
+ from langchain_community.document_loaders import TextLoader
12
+ from langchain_community.vectorstores import FAISS
13
+ from langchain_community.embeddings import HuggingFaceEmbeddings
14
+ from langchain.text_splitter import CharacterTextSplitter
15
+ from langchain_core.output_parsers import StrOutputParser
16
+ from langchain_core.runnables import RunnablePassthrough
17
+ from langchain_fireworks import ChatFireworks
18
+ from langchain_community.llms import Ollama
19
+ from langchain_core.prompts import ChatPromptTemplate
20
+ from rich.console import Console
21
+ from rich.markdown import Markdown
22
+
23
+ transformers.logging.set_verbosity_error()
24
+
25
+
26
+
27
+ def image_to_query(image):
28
+ """
29
+ input: Image
30
+
31
+ function: Performs image classification using fine-tuned model
32
+
33
+ output: Query for the LLM
34
+ """
35
+ image = Image.open('test.jpg')
36
+ classifier = pipeline("image-classification", model = "nprasad24/bean_classifier")
37
+
38
+ scores = classifier(image)
39
+
40
+ # Get the dictionary with the maximum score
41
+ max_score_dict = max(scores, key=lambda x: x['score'])
42
+
43
+ # Extract the label with the maximum score
44
+ label_with_max_score = max_score_dict['label']
45
+
46
+ # script to check if the image uploaded is indeed a leaf or not
47
+ counter = 0
48
+ for ele in scores:
49
+ if 0.2 <= ele['score'] <= 0.4:
50
+ counter += 1
51
+
52
+ if label_with_max_score == 'healthy' and counter != 3:
53
+ query = "The plant is healthy. Give tips on maintaining the plant"
54
+ elif label_with_max_score == 'bean_rust' and counter != 3:
55
+ query = "The detected disease is bean rust. Explain the disease"
56
+ elif label_with_max_score == 'angular_leaf_spot' and counter != 3:
57
+ query = "The detected disease is angular leaf spot. Explain the disease"
58
+ else:
59
+ query = "Given image is not of a plant."
60
+
61
+ return query
62
+
63
+ def ragChain():
64
+ """
65
+ function: creates a rag chain
66
+
67
+ output: rag chain
68
+ """
69
+ loader = TextLoader("knowledgeBase.txt")
70
+ docs = loader.load()
71
+
72
+ text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
73
+ docs = text_splitter.split_documents(docs)
74
+
75
+ vectorstore = FAISS.from_documents(documents = docs, embedding = HuggingFaceEmbeddings())
76
+ retriever = vectorstore.as_retriever(search_type = "similarity", search_kwargs = {"k": 5})
77
+
78
+ APIKEY = "o7T3gVx9Vt8GSJbLyPV1974vF8LXVp01CWqOkWQuHgoHm07H"
79
+ os.environ["FIREWORKS_API_KEY"] = APIKEY
80
+
81
+ llm = ChatFireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct")
82
+
83
+ prompt = ChatPromptTemplate.from_messages(
84
+ [
85
+ (
86
+ "system",
87
+ """You are a knowledgeable agricultural assistant. If a disease is detected, you have to give information on the disease.
88
+ If the plant is healthy, just give maintenance tips. """
89
+ ),
90
+ (
91
+ "human",
92
+ """Provide information about the leaf disease in question in bullet points.
93
+ Start your answer by mentioning the disease (if any) or healthy in this format: 'Condition: disease name'.
94
+ If the image is not of a plant, ask human to upload image of a plant and stop generating any response.
95
+ """,
96
+ ),
97
+
98
+ ("human", "{context}, {question}"),
99
+ ]
100
+ )
101
+
102
+ rag_chain = (
103
+ {
104
+ "context": retriever,
105
+ "question": RunnablePassthrough()
106
+ }
107
+ | prompt
108
+ | llm
109
+ | StrOutputParser()
110
+ )
111
+
112
+ return rag_chain
113
+
114
+ def generate_response(rag_chain, query):
115
+ """
116
+ input: rag chain, query
117
+
118
+ function: generates response using llm and knowledge base
119
+
120
+ output: generated response by the llm
121
+ """
122
+ return Markdown(rag_chain.invoke(f"{query}"))
123
+
124
+ def main():
125
+ console = Console()
126
+
127
+ query = image_to_query('test2.jpeg')
128
+ chain = ragChain()
129
+ output = generate_response(chain, query)
130
+ output = Markdown(output)
131
+ return output
132
+
133
+ title = "Bean Classifier and Instructor"
134
+ description = "Professor Bean is an agricultural expert. He will guide you on how to protect your plants from bean diseases"
135
+ app = gr.Interface(fn=main, inputs="image", outputs="text", title=title,
136
+ description=description)
137
+ app.launch(share=True)