legacy107 committed
Commit 2cc4144 · 1 Parent(s): ebe2f40

Create main.py

Files changed (1)
  1. main.py +132 -0
main.py ADDED
@@ -0,0 +1,132 @@
+ import gradio as gr
+ from gradio.components import Textbox
+ from transformers import AutoTokenizer, T5ForConditionalGeneration
+ from sentence_transformers import SentenceTransformer, util
+ import torch
+ import datasets
+ import nltk
+ nltk.download('punkt')
+ 
+ # Load the bi-encoder used to rank context chunks against the question
+ bi_encoder = SentenceTransformer('legacy107/multi-qa-mpnet-base-dot-v1-wikipedia-search')
+ bi_encoder.max_seq_length = 256
+ top_k = 3
+ 
+ # Load the fine-tuned model and tokenizer
+ model_name = "legacy107/flan-t5-large-ia3-wiki-merged"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = T5ForConditionalGeneration.from_pretrained(model_name)
+ max_length = 512
+ max_target_length = 200
+ 
+ # Load the dataset and sample ten random test examples for the demo
+ dataset = datasets.load_dataset("legacy107/qa_wikipedia_retrieved_chunks", split="test")
+ dataset = dataset.shuffle()
+ dataset = dataset.select(range(10))
+ 
+ # Context chunking: split the article into chunks of whole sentences of
+ # roughly chunk_size words, with consecutive chunks sharing an overlap of
+ # about overlap * chunk_size words
+ def chunk_splitter(context, chunk_size=100, overlap=0.20):
+     overlap_size = chunk_size * overlap
+     sentences = nltk.sent_tokenize(context)
+ 
+     chunks = []
+     text = sentences[0]
+ 
+     if len(sentences) == 1:
+         chunks.append(text)
+ 
+     i = 1
+     while i < len(sentences):
+         text += " " + sentences[i]
+         i += 1
+         while i < len(sentences) and len(nltk.word_tokenize(f"{text} {sentences[i]}")) <= chunk_size:
+             text += " " + sentences[i]
+             i += 1
+ 
+         text = text.replace('\\"', '"').replace("\\'", "'").replace('\n\n\n', ' ').replace('\n\n', ' ').replace('\n', ' ')
+         chunks.append(text)
+ 
+         if i >= len(sentences):
+             break
+ 
+         # Seed the next chunk with the trailing sentences of the current one.
+         # Decrement j before the loop so the seed sentence is not prepended
+         # a second time (the original duplicated sentences[j] here).
+         j = i - 1
+         text = sentences[j]
+         j -= 1
+         while j >= 0 and len(nltk.word_tokenize(f"{sentences[j]} {text}")) <= overlap_size:
+             text = sentences[j] + " " + text
+             j -= 1
+ 
+     return chunks
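+ 
+ # Illustration (hypothetical numbers): with chunk_size=100 and overlap=0.20,
+ # an article of ~30-word sentences yields chunks of roughly three whole
+ # sentences each, and every new chunk is seeded with the last sentence of
+ # the previous chunk (plus earlier sentences while they still fit within
+ # the ~20-word overlap budget), so answers spanning a boundary survive.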
+ 
+ # Retrieve the top_k chunks most relevant to the query with the bi-encoder
+ def retrieve_context(query, contexts):
+     corpus_embeddings = bi_encoder.encode(contexts, convert_to_tensor=True, show_progress_bar=False)
+ 
+     question_embedding = bi_encoder.encode(query, convert_to_tensor=True, show_progress_bar=False)
+     if torch.cuda.is_available():  # only move to GPU when one exists
+         question_embedding = question_embedding.cuda()
+     hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)
+     hits = hits[0]
+ 
+     hits = sorted(hits, key=lambda x: x['score'], reverse=True)
+     return " ".join([contexts[hit['corpus_id']] for hit in hits[:top_k]]).replace("\n", " ")
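+ 
+ # Note: util.semantic_search returns one hit list per query; each hit is a
+ # dict with 'corpus_id' and 'score', already sorted by decreasing score, so
+ # the explicit sort above is defensive. The top_k chunks are then joined
+ # into a single context string for the reader model.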
+ 
+ 
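+ # generate_answer below calls clean_data, which this file never defines.
+ # A minimal stand-in (an assumption, not the author's original) that just
+ # normalizes whitespace before chunking:
+ def clean_data(text):
+     return " ".join(text.split())
+ 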
+ # Generate an answer: retrieve the best chunks, then run the reader model
+ def generate_answer(question, context, ground):
+     contexts = chunk_splitter(clean_data(context))
+     context = retrieve_context(question, contexts)
+ 
+     # Combine question and context
+     input_text = f"question: {question} context: {context}"
+ 
+     # Tokenize the input text
+     input_ids = tokenizer(
+         input_text,
+         return_tensors="pt",
+         padding="max_length",
+         truncation=True,
+         max_length=max_length,
+     ).input_ids
+ 
+     # Generate the answer
+     with torch.no_grad():
+         generated_ids = model.generate(input_ids=input_ids, max_new_tokens=max_target_length)
+ 
+     # Decode and return the generated answer
+     generated_answer = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+ 
+     return generated_answer, context, ground
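+ 
+ # Example call (hypothetical strings):
+ #   answer, ctx, truth = generate_answer("Who wrote Hamlet?", article_text, "William Shakespeare")
+ # returns the generated answer plus the retrieved context and the untouched
+ # ground truth, matching the three output boxes of the interface below.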
+ 
+ 
+ # List examples from the dataset for the Gradio examples panel
+ def list_examples():
+     examples = []
+     for example in dataset:
+         context = example["article"]
+         question = example["question"]
+         answer = example["answer"]
+         examples.append([question, context, answer])
+     return examples
+ 
+ 
+ # Create a Gradio interface
+ iface = gr.Interface(
+     fn=generate_answer,
+     inputs=[
+         Textbox(label="Question"),
+         Textbox(label="Context"),
+         Textbox(label="Ground truth")
+     ],
+     outputs=[
+         Textbox(label="Generated Answer"),
+         Textbox(label="Retrieved Context"),
+         Textbox(label="Ground Truth")
+     ],
+     examples=list_examples()
+ )
+ 
+ # Launch the Gradio interface
+ iface.launch()