imSleepy commited on
Commit
ebcb4b0
1 Parent(s): 42d4112

Uploaded chatbot.py

Browse files
Files changed (1) hide show
  1. chatbot.py +55 -0
chatbot.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
2
+ from sentence_transformers import SentenceTransformer
3
+ from pinecone import Pinecone
4
+
5
+ device = 'cpu'
6
+
7
+ # Initialize Pinecone instance
8
+ pc = Pinecone(api_key='89eeb534-da10-4068-92f7-12eddeabe1e5')
9
+
10
+ # Check if the index exists; if not, create it
11
+ index_name = 'abstractive-question-answering'
12
+ index = pc.Index(index_name)
13
+
14
+ def load_models():
15
+ print("Loading models...")
16
+
17
+ retriever = SentenceTransformer("flax-sentence-embeddings/all_datasets_v3_mpnet-base")
18
+ tokenizer = T5Tokenizer.from_pretrained('t5-base')
19
+ generator = T5ForConditionalGeneration.from_pretrained('t5-base').to(device)
20
+
21
+ return retriever, generator, tokenizer
22
+
23
+ retriever, generator, tokenizer = load_models()
24
+
25
+ def process_query(query):
26
+ # Query Pinecone
27
+ xq = retriever.encode([query]).tolist()
28
+ xc = index.query(vector=xq, top_k=1, include_metadata=True)
29
+
30
+ # Print the response to check the structure
31
+ print("Pinecone response:", xc)
32
+
33
+ # Check if 'matches' exists and is a list
34
+ if 'matches' in xc and isinstance(xc['matches'], list):
35
+ context = [m['metadata']['Output'] for m in xc['matches']]
36
+ context_str = " ".join(context)
37
+ formatted_query = f"answer the question: {query} context: {context_str}"
38
+ else:
39
+ # Handle the case where 'matches' isn't found or isn't in the expected format
40
+ context_str = ""
41
+ formatted_query = f"answer the question: {query} context: {context_str}"
42
+
43
+ # Generate answer using T5 model
44
+ output_text = context_str
45
+ if len(output_text.splitlines()) > 5:
46
+ return output_text
47
+
48
+ if output_text.lower() == "none":
49
+ return "The topic is not covered in the student manual."
50
+
51
+ inputs = tokenizer.encode(formatted_query, return_tensors="pt", max_length=512, truncation=True).to(device)
52
+ ids = generator.generate(inputs, num_beams=4, min_length=10, max_length=60, repetition_penalty=1.2)
53
+ answer = tokenizer.decode(ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
54
+
55
+ return answer