checkiejan committed
Commit 4ad99c8 · 1 parent: 8050cb3

Create app.py

Files changed (1)
  1. app.py +104 -0
app.py ADDED
@@ -0,0 +1,104 @@
import gradio as gr
from gradio.components import Textbox
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, T5ForConditionalGeneration
from peft import PeftModel
import torch
import datasets

# Load the fine-tuned QA model and tokenizer
model_name = "google/flan-t5-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
# Use the name returned by load_adapter rather than hard-coding
# "question_answering", so activation cannot drift out of sync
adapter_name = model.load_adapter("legacy107/adapter-flan-t5-large-bottleneck-adapter-cpgQA", source="hf")
model.set_active_adapters(adapter_name)

# Paraphrasing model: the same base model with an IA3 PEFT adapter
peft_name = "legacy107/flan-t5-large-ia3-bioasq-paraphrase"
paraphrase_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
paraphrase_model = PeftModel.from_pretrained(paraphrase_model, peft_name)
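
# Note (editorial sketch, not part of the original commit): load_adapter /
# set_active_adapters belong to the adapter-transformers API, while PeftModel
# comes from the separate peft library, so the QA model and the paraphrase
# model use independent adapter stacks. This assumes adapter-transformers (or
# its successor, the adapters package) is installed alongside peft; plain
# transformers' load_adapter does not accept a source argument.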

max_length = 512
max_target_length = 128

# Load the dataset and keep five random examples for the demo
dataset = datasets.load_dataset("minh21/cpgQA-v1.0-unique-context-test-10-percent-validation-10-percent", split="test")
dataset = dataset.shuffle()
dataset = dataset.select(range(5))


def paraphrase_answer(question, answer):
    # Combine the question with the generated answer
    input_text = f"question: {question}. Paraphrase the answer to make it more natural answer: {answer}"

    # Tokenize the input text
    input_ids = tokenizer(
        input_text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_length,
    ).input_ids

    # Generate the paraphrased answer
    with torch.no_grad():
        generated_ids = paraphrase_model.generate(input_ids=input_ids, max_new_tokens=max_target_length)

    # Decode and return the generated answer
    paraphrased_answer = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    return paraphrased_answer


# Define your function to generate answers
def generate_answer(question, context):
    # Combine question and context
    input_text = f"question: {question} context: {context}"

    # Tokenize the input text
    input_ids = tokenizer(
        input_text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_length,
    ).input_ids

    # Generate the answer
    with torch.no_grad():
        generated_ids = model.generate(input_ids, max_new_tokens=max_target_length)

    # Decode the generated answer
    generated_answer = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    # Paraphrase the answer
    paraphrased_answer = paraphrase_answer(question, generated_answer)

    return generated_answer, paraphrased_answer


# Define a function to list examples from the dataset
def list_examples():
    examples = []
    for example in dataset:
        context = example["context"]
        question = example["question"]
        examples.append([question, context])
    return examples
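
# Editorial sketch (not in the original commit): a quick smoke test of the
# full pipeline before launching the UI; uncomment to run one sampled example.
# sample = dataset[0]
# answer, natural = generate_answer(sample["question"], sample["context"])
# print("Generated:", answer)
# print("Paraphrased:", natural)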


# Create a Gradio interface
iface = gr.Interface(
    fn=generate_answer,
    inputs=[
        Textbox(label="Question"),
        Textbox(label="Context"),
    ],
    outputs=[
        Textbox(label="Generated Answer"),
        Textbox(label="Natural Answer"),
    ],
    examples=list_examples(),
)

# Launch the Gradio interface
iface.launch()
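
Once the app is running, the endpoint can also be queried programmatically. The sketch below uses the separate gradio_client package against the default local URL; the URL and the question/context strings are illustrative assumptions, and a gr.Interface exposes its endpoint as /predict by default.

    from gradio_client import Client

    client = Client("http://127.0.0.1:7860/")  # default local URL; swap in the Space URL when deployed
    generated, natural = client.predict(
        "What is the recommended first-line treatment?",  # question (illustrative)
        "The guideline recommends ...",                   # context (illustrative)
        api_name="/predict",
    )
    print("Generated Answer:", generated)
    print("Natural Answer:", natural)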