suchinth08 committed on
Commit
d4ed260
·
verified ·
1 Parent(s): adce4c9

Upload 4 files

Browse files
Files changed (4) hide show
  1. README.md +3 -12
  2. lawmain.py +29 -0
  3. lppchain.py +58 -0
  4. lpphelper.py +50 -0
README.md CHANGED
@@ -1,12 +1,3 @@
1
- ---
2
- title: Lawllm
3
- emoji: 📚
4
- colorFrom: red
5
- colorTo: purple
6
- sdk: streamlit
7
- sdk_version: 1.31.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ # lawllm
2
+ Law LLM for working with Indian Judiciary Acts, Orders, Provisions and Citations
3
+ ![image](https://github.com/suchinth08/lawllm/assets/21136148/9e47d810-c3b4-487a-9663-07ad9b3186a5)
 
 
 
 
 
 
 
 
 
lawmain.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ from lppchain import get_lpphelper_chain,process_llm_response
4
+
5
+ #st.title( "Lakna Reddy & Associates 🤖")
6
# Page header: three columns with the logo in the first and the firm name in
# the last; the middle column is not written to.
col1, mid, col2 = st.columns(3)
image = Image.open('lawimage2.jpg')
col1.image(image, width=150)
col2.markdown("## Lakna Reddy & Associates")

# Free-text question box; the answer flow below runs only when it is non-empty.
question = st.text_input("Question: ")
14
@st.cache_resource
def load_qa_chain():
    """Return the retrieval QA chain built by ``get_lpphelper_chain``.

    Wrapped in ``st.cache_resource`` so Streamlit can reuse the returned
    chain instead of rebuilding it on every call.
    """
    return get_lpphelper_chain()
18
+
19
# Answer flow: runs only once the user has entered a question.
if question:
    chain = load_qa_chain()
    with st.spinner('Generating response...'):
        response = chain.invoke(question)
    # Raw chain output goes to the server console for debugging.
    print(response)
    answer = process_llm_response(response)
    st.header("Answer")
    # Strip any "<pad>" tokens from the text before showing it.
    displayed_answer = answer.replace("<pad>", "")
    st.write(displayed_answer)
lppchain.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import transformers
2
+ import torch
3
+ import os
4
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
+ from transformers import pipeline
6
+ from langchain.llms import HuggingFacePipeline
7
+ from langchain.vectorstores import Chroma
8
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
9
+ from langchain.chains import RetrievalQA
10
+ from langchain.document_loaders import TextLoader
11
+ from langchain.document_loaders import PyPDFLoader
12
+ from langchain.document_loaders import DirectoryLoader
13
+ from InstructorEmbedding import INSTRUCTOR
14
+ from langchain.embeddings import HuggingFaceInstructEmbeddings
15
+ from langchain_community.vectorstores import Chroma
16
+ import textwrap
17
+ import streamlit as st
18
+
19
# Everything below runs at import time and is shared through module globals.

# Directory holding the persisted Chroma index.
persist_directory = 'db'

# Instructor embedding model handed to the vector store.
instructor_embeddings = HuggingFaceInstructEmbeddings(
    model_name="hkunlp/instructor-base",
)
embedding = instructor_embeddings

# FastChat-T5 tokenizer and model wrapped in a text2text-generation pipeline
# with output capped at max_length=256.
tokenizer = AutoTokenizer.from_pretrained("lmsys/fastchat-t5-3b-v1.0")
model = AutoModelForSeq2SeqLM.from_pretrained("lmsys/fastchat-t5-3b-v1.0")
pipe = pipeline(
    "text2text-generation",
    model=model,
    tokenizer=tokenizer,
    max_length=256,
)
local_llm = HuggingFacePipeline(pipeline=pipe)

# Open the persisted index and expose it as a retriever with k=3.
vectordb = Chroma(
    persist_directory=persist_directory,
    embedding_function=embedding,
)
retriever = vectordb.as_retriever(search_kwargs={"k": 3})
28
+
29
def get_lpphelper_chain():
    """Build a ``RetrievalQA`` chain from the module-level LLM and retriever.

    Returns:
        The chain produced by ``RetrievalQA.from_chain_type`` using the
        "stuff" chain type with ``return_source_documents=True``.
    """
    chain_options = {
        "llm": local_llm,
        "chain_type": "stuff",
        "retriever": retriever,
        "return_source_documents": True,
    }
    return RetrievalQA.from_chain_type(**chain_options)
35
+
36
def wrap_text_preserve_newlines(text, width=110):
    """Wrap ``text`` to ``width`` columns while keeping its line breaks.

    The text is split on ``'\\n'``, each piece is wrapped on its own with
    ``textwrap.fill``, and the pieces are joined back with ``'\\n'``.

    Args:
        text: The string to wrap.
        width: Maximum line width passed to ``textwrap.fill``.

    Returns:
        The re-wrapped string.
    """
    return '\n'.join(
        textwrap.fill(segment, width=width) for segment in text.split('\n')
    )
47
+
48
def process_llm_response(llm_response):
    """Print the wrapped answer and its sources, then return the answer text.

    Args:
        llm_response: Mapping with a ``'result'`` string and a
            ``'source_documents'`` iterable whose items expose
            ``metadata['source']``.

    Returns:
        The wrapped ``'result'`` text with every ``"<pad>"`` removed.
    """
    wrap_text = wrap_text_preserve_newlines(llm_response['result'])
    answer = wrap_text.replace("<pad>", "")

    # str.join returns a new string and treats its argument as the items to
    # join, so it cannot be used to append; collect the source names instead.
    source_names = [
        str(source.metadata['source'])
        for source in llm_response["source_documents"]
    ]

    print(answer)
    print('\n\nSources:')
    for name in source_names:
        print(name)

    return answer
56
+
57
if __name__ == "__main__":
    # Running the module directly only builds the chain; the result is unused.
    get_lpphelper_chain()
lpphelper.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import transformers
2
+ import torch
3
+ import os
4
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
+ from transformers import pipeline
6
+ from langchain.llms import HuggingFacePipeline
7
+ from langchain.vectorstores import Chroma
8
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
9
+ from langchain.chains import RetrievalQA
10
+ from langchain.document_loaders import TextLoader
11
+ from langchain.document_loaders import PyPDFLoader
12
+ from langchain.document_loaders import DirectoryLoader
13
+ from InstructorEmbedding import INSTRUCTOR
14
+ from langchain.embeddings import HuggingFaceInstructEmbeddings
15
+ from langchain_community.vectorstores import Chroma
16
+ import textwrap
17
+
18
def gen_vectordb(source_dir='C:/Users/SudheerRChinthala/sivallm/new_papers',
                 persist_directory='db'):
    """Index the PDFs in ``source_dir`` into a persisted Chroma store.

    Each PDF matching ``./*.pdf`` is loaded with ``PyPDFLoader``, split into
    chunks (``chunk_size=1000``, ``chunk_overlap=200``), embedded with the
    ``hkunlp/instructor-base`` model and written to ``persist_directory``.

    Only the embedding model is loaded here. The earlier version also loaded
    the FastChat-T5 model and built a QA chain that were never used.

    Args:
        source_dir: Folder scanned for PDF files. Defaults to the path that
            was previously hard-coded.
        persist_directory: Folder where the Chroma index is stored.
    """
    loader = DirectoryLoader(source_dir, glob="./*.pdf", loader_cls=PyPDFLoader)
    documents = loader.load()

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    texts = text_splitter.split_documents(documents)

    embedding = HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-base")
    vectordb = Chroma.from_documents(documents=texts,
                                     embedding=embedding,
                                     persist_directory=persist_directory)
    # Explicitly write the index to disk, as the original did.
    vectordb.persist()


if __name__ == "__main__":
    gen_vectordb()