raghav-gaggar committed
Commit cff1b65 · verified · 1 parent: 3f23ae2

Create app.py

Files changed (1)
  1. app.py +115 -0
app.py ADDED
@@ -0,0 +1,115 @@
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from datasets import concatenate_datasets, load_dataset
from peft import PeftModel, PeftConfig
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.docstore.document import Document
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the SAMSum train and validation splits and concatenate them
# (note: prepared here but not referenced again in this app)
train_dataset = load_dataset("samsum", split="train", trust_remote_code=True)
val_dataset = load_dataset("samsum", split="validation", trust_remote_code=True)
samsum_dataset = concatenate_datasets([train_dataset, val_dataset])

# Load FLAN-T5 base, attach the RLHF-tuned PEFT adapter, and merge the
# adapter weights into the base model for inference
model_name = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
rlhf_model_path = "raghav-gaggar/PEFT_RLHF_TextSummarizer"
config = PeftConfig.from_pretrained(rlhf_model_path)
ppo_model = PeftModel.from_pretrained(base_model, rlhf_model_path).to(device)
merged_model = ppo_model.merge_and_unload().to(device)

base_model.eval()
ppo_model.eval()
merged_model.eval()

dialogsum_dataset = load_dataset("knkarthick/dialogsum", trust_remote_code=True)

def format_dialogsum_as_document(example):
    return Document(page_content=f"Dialogue:\n {example['dialogue']}\n\nSummary: {example['summary']}")

# Create documents from the DialogSum dataset
documents = []
for split in ["train", "validation", "test"]:
    documents.extend([format_dialogsum_as_document(example) for example in dialogsum_dataset[split]])

# Split the documents into chunks
text_splitter = CharacterTextSplitter(chunk_size=5200, chunk_overlap=0)
docs = text_splitter.split_documents(documents)

# Create embeddings and a FAISS vector store over the DialogSum documents
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={"device": "cuda" if torch.cuda.is_available() else "cpu"},
    encode_kwargs={"batch_size": 32},
)

vector_store = FAISS.from_documents(docs, embeddings)

# Retriever that returns the single most similar DialogSum example
retriever = vector_store.as_retriever(search_kwargs={"k": 1})

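# Optional sanity check (hypothetical dialogue; uncomment to confirm the
# retriever surfaces a plausible exemplar):
# nearest = retriever.get_relevant_documents("A: Lunch at noon? B: Sure, see you then.")
# print(nearest[0].page_content[:200])
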
# Prompt that shows the model one retrieved exemplar before the target dialogue
prompt_template = """
Concisely summarize the dialogue at the end, in the style of the example provided -

Example -
{context}

Dialogue to be summarized:
{question}

Summary:"""

PROMPT = PromptTemplate(
    template=prompt_template, input_variables=["context", "question"]
)

# Create a Hugging Face summarization pipeline around the merged model
summarization_pipeline = pipeline(
    "summarization",
    model=merged_model,
    tokenizer=tokenizer,
    max_length=150,
    min_length=20,
    do_sample=False,
)

# Wrap the pipeline in a LangChain LLM
llm = HuggingFacePipeline(pipeline=summarization_pipeline)

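# Quick smoke test of the raw pipeline (hypothetical dialogue; uncomment to try):
# print(summarization_pipeline("A: Did you finish the report? B: Almost, sending it tonight.")[0]["summary_text"])
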
# RetrievalQA chain: fetch one exemplar, then summarize with the prompt above
qa_chain = RetrievalQA.from_chain_type(
    llm, retriever=retriever, chain_type_kwargs={"prompt": PROMPT}
)

# Function backing the Gradio interface
def summarize_conversation(question):
    result = qa_chain.invoke({"query": question})
    return result["result"]

# Create the Gradio interface
iface = gr.Interface(
    fn=summarize_conversation,
    inputs=gr.Textbox(lines=10, label="Enter conversation here"),
    outputs=gr.Textbox(label="Summary"),
    title="Conversation Summarizer",
    description="Enter a conversation, and the AI will provide a concise summary.",
)

# Launch the app
iface.launch()
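
Once the Space is running, the endpoint can also be exercised programmatically with gradio_client; a minimal sketch, assuming the Space is public (the Space id below is hypothetical):

    from gradio_client import Client

    client = Client("raghav-gaggar/conversation-summarizer")  # hypothetical Space id
    summary = client.predict(
        "Amanda: I baked cookies. Do you want some?\nJerry: Sure!",
        api_name="/predict",  # default endpoint name for a gr.Interface
    )
    print(summary)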