suchinth08 commited on
Commit
8ecb9c7
·
verified ·
1 Parent(s): 9a092af

Upload lawchain.py

Browse files
Files changed (1) hide show
  1. lawchain.py +58 -0
lawchain.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import transformers
2
+ import torch
3
+ import os
4
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
+ from transformers import pipeline
6
+ from langchain.llms import HuggingFacePipeline
7
+ from langchain.vectorstores import Chroma
8
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
9
+ from langchain.chains import RetrievalQA
10
+ from langchain.document_loaders import TextLoader
11
+ from langchain.document_loaders import PyPDFLoader
12
+ from langchain.document_loaders import DirectoryLoader
13
+ from InstructorEmbedding import INSTRUCTOR
14
+ from langchain.embeddings import HuggingFaceInstructEmbeddings
15
+ from langchain_community.vectorstores import Chroma
16
+ import textwrap
17
+ import streamlit as st
18
+
19
+ persist_directory = 'db'
20
+ instructor_embeddings = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-base")
21
+ embedding = instructor_embeddings
22
+ tokenizer = AutoTokenizer.from_pretrained("lmsys/fastchat-t5-3b-v1.0")
23
+ model = AutoModelForSeq2SeqLM.from_pretrained("lmsys/fastchat-t5-3b-v1.0")
24
+ pipe = pipeline("text2text-generation",model=model, tokenizer=tokenizer, max_length=256)
25
+ local_llm = HuggingFacePipeline(pipeline=pipe)
26
+ vectordb = Chroma(persist_directory=persist_directory,embedding_function=embedding)
27
+ retriever = vectordb.as_retriever(search_kwargs={"k": 3})
28
+
29
+ def get_lpphelper_chain():
30
+ qa_chain = RetrievalQA.from_chain_type(llm=local_llm,
31
+ chain_type="stuff",
32
+ retriever=retriever,
33
+ return_source_documents=True)
34
+ return qa_chain
35
+
36
+ def wrap_text_preserve_newlines(text, width=110):
37
+ # Split the input text into lines based on newline characters
38
+ lines = text.split('\n')
39
+
40
+ # Wrap each line individually
41
+ wrapped_lines = [textwrap.fill(line, width=width) for line in lines]
42
+
43
+ # Join the wrapped lines back together using newline characters
44
+ wrapped_text = '\n'.join(wrapped_lines)
45
+
46
+ return wrapped_text
47
+
48
+ def process_llm_response(llm_response):
49
+ wrap_text = wrap_text_preserve_newlines(llm_response['result'])
50
+ sources = '\n\nSources:'
51
+ print('\n\nSources:')
52
+ for source in llm_response["source_documents"]:
53
+ sources.join(source.metadata['source'])
54
+ print(wrap_text.join(sources))
55
+ return wrap_text.replace("<pad>","")
56
+
57
+ if __name__=="__main__":
58
+ get_lpphelper_chain()