Pouria Aghaomidi commited on
Commit
0c5ef1f
·
1 Parent(s): d07e714

Add application file

Browse files
Files changed (1) hide show
  1. app.py +43 -0
app.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import gradio as gr
4
+ from persian_rag import PersianRAG
5
+
6
+ os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
7
+
8
+
9
+ # Function to load CSV and initialize PersianRAG
10
+ def init_rag(knowledge_file, embedding_model, llm_model, device, retrieved_docs):
11
+ knowledge = pd.read_csv(knowledge_file)
12
+ rag_system = PersianRAG(knowledge, embedding_model=embedding_model, llm_model=llm_model, device=device,
13
+ retrieved_docs=retrieved_docs)
14
+ return rag_system
15
+
16
+
17
+ # Function to handle querying
18
+ def query_rag(rag_system, query):
19
+ return rag_system.rag(query)
20
+
21
+
22
+ # Gradio interface to upload CSV and configure RAG system
23
+ def rag_interface(knowledge_file, query, embedding_model, llm_model, device, retrieved_docs):
24
+ rag_system = init_rag(knowledge_file, embedding_model, llm_model, device, retrieved_docs)
25
+ return query_rag(rag_system, query)
26
+
27
+
28
+ # Create Gradio interface
29
+ interface = gr.Interface(
30
+ fn=rag_interface,
31
+ inputs=[
32
+ gr.File(label="Upload Knowledge Base CSV"),
33
+ gr.Textbox(label="Enter your query"),
34
+ gr.Dropdown(choices=["LABSE", "paraphrase-multilingual-mpnet-base-v2"], value="LABSE", label="Embedding Model"),
35
+ gr.Textbox(value="MehdiHosseiniMoghadam/AVA-Mistral-7B-V2", label="LLM Model Name"),
36
+ gr.Dropdown(choices=["cuda", "cpu"], value="cuda", label="Device"),
37
+ gr.Slider(minimum=1, maximum=5, step=1, value=3, label="Number of Retrieved Documents")],
38
+ outputs="text",
39
+ title="Persian RAG System",
40
+ description="Upload a CSV file as the knowledge base, ask a question, and get an answer.")
41
+
42
+ # Launch the Gradio interface
43
+ interface.launch(share=False)