gaur3009 commited on
Commit
4eb325f
·
verified ·
1 Parent(s): aed192b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -0
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import PyPDF2
3
+ from transformers import AutoTokenizer, AutoModel
4
+ import torch
5
+ import weaviate
6
+ import cohere
7
+
8
+ auth_config = weaviate.AuthApiKey(api_key="16LRz5YwOtnq8ov51Lhg1UuAollpsMgspulV")
9
+ client = weaviate.Client(
10
+ url="https://wkoll9rds3orbu9fhzfr2a.c0.asia-southeast1.gcp.weaviate.cloud",
11
+ auth_client_secret=auth_config
12
+ )
13
+ cohere_client = cohere.Client("LEvCVeZkqZMW1aLYjxDqlstCzWi4Cvlt9PiysqT8")
14
+
15
+ def load_pdf(file):
16
+ reader = PyPDF2.PdfReader(file)
17
+ text = ''
18
+ for page in range(len(reader.pages)):
19
+ text += reader.pages[page].extract_text()
20
+ return text
21
+
22
+ tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
23
+ model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
24
+
25
+ def get_embeddings(text):
26
+ inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
27
+ with torch.no_grad():
28
+ embeddings = model(**inputs).last_hidden_state.mean(dim=1).squeeze().cpu().numpy()
29
+ return embeddings
30
+
31
+ def upload_document_chunks(chunks):
32
+ for idx, chunk in enumerate(chunks):
33
+ embedding = get_embeddings(chunk)
34
+ client.data_object.create(
35
+ {"content": chunk},
36
+ "Document",
37
+ vector=embedding.tolist()
38
+ )
39
+
40
+ def query_answer(query):
41
+ query_embedding = get_embeddings(query)
42
+ result = client.query.get("Document", ["content"])\
43
+ .with_near_vector({"vector": query_embedding.tolist()})\
44
+ .with_limit(3)\
45
+ .do()
46
+ return result
47
+
48
+ def generate_response(context, query):
49
+ response = cohere_client.generate(
50
+ model='command',
51
+ prompt=f"Context: {context}\n\nQuestion: {query}?\nAnswer:",
52
+ max_tokens=100
53
+ )
54
+ return response.generations[0].text.strip()
55
+
56
+ def qa_pipeline(pdf_file, query):
57
+ document_text = load_pdf(pdf_file)
58
+ document_chunks = [document_text[i:i+500] for i in range(0, len(document_text), 500)]
59
+
60
+ upload_document_chunks(document_chunks)
61
+
62
+ response = query_answer(query)
63
+ context = ' '.join([doc['content'] for doc in response['data']['Get']['Document']])
64
+
65
+ answer = generate_response(context, query)
66
+
67
+ return context, answer
68
+
69
+ with gr.Blocks() as demo:
70
+ gr.Markdown("# Interactive QA Bot")
71
+
72
+ pdf_input = gr.File(label="Upload a PDF file", file_types=[".pdf"])
73
+ query_input = gr.Textbox(label="Ask a question")
74
+
75
+ doc_segments_output = gr.Textbox(label="Retrieved Document Segments")
76
+ answer_output = gr.Textbox(label="Answer")
77
+
78
+ gr.Button("Submit").click(
79
+ qa_pipeline,
80
+ inputs=[pdf_input, query_input],
81
+ outputs=[doc_segments_output, answer_output]
82
+ )
83
+
84
+ demo.launch()