ayush0504 committed on
Commit
d397597
·
verified ·
1 Parent(s): ae40453

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -44
app.py CHANGED
@@ -1,46 +1,58 @@
 
 
 
1
  import streamlit as st
2
- from unsloth import FastLanguageModel
3
- from transformers import TextStreamer
4
-
5
- # Page Configuration
6
- st.set_page_config(page_title="AI Traffic Law Advisor", layout="wide")
7
-
8
- # Load the LoRA model
9
- MODEL_PATH = "./lora_model"
10
-
11
- @st.cache_resource(show_spinner=False)
12
- def load_model():
13
- # Load model and tokenizer
14
- model, tokenizer = FastLanguageModel.from_pretrained(
15
- MODEL_PATH,
16
- device_map="auto"
17
- )
18
- # Enable inference mode
19
- model = FastLanguageModel.for_inference(model)
20
- return model, tokenizer
21
-
22
- model, tokenizer = load_model()
23
-
24
- st.title("AI Traffic Law Advisor")
25
-
26
- user_query = st.text_area("Enter your legal question about traffic rules in India:", "")
27
-
28
- if st.button("Get Advice"):
29
- if user_query.strip():
30
- messages = [{"role": "user", "content": user_query}]
31
- # Tokenize input
32
- inputs = tokenizer.apply_chat_template(
33
- messages,
34
- tokenize=True,
35
- add_generation_prompt=True,
36
- return_tensors="pt"
37
- ).to(model.device)
38
-
39
- # Stream response
 
 
 
40
  text_streamer = TextStreamer(tokenizer, skip_prompt=True)
41
-
42
- st.markdown("**AI Response:**")
43
- with st.spinner("Generating response..."):
44
- model.generate(input_ids=inputs, streamer=text_streamer, max_new_tokens=1048, temperature=0.7)
45
- else:
46
- st.warning("Please enter a query.")
 
 
 
 
 
 
 
1
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer, TextStreamer
import streamlit as st

# Initialize Streamlit UI
st.title("Legal Query Chatbot")
st.write("Ask questions related to Indian traffic laws and get AI-generated responses.")

# Location of the LoRA fine-tuned adapter (references its base model).
model_path = "lora_model"
load_in_4bit = True


@st.cache_resource(show_spinner=False)
def _load_model_and_tokenizer():
    """Load the LoRA model and tokenizer once per server process.

    Streamlit re-executes this script on every UI interaction; without
    st.cache_resource the multi-GB model would be reloaded on each rerun.
    """
    model = AutoPeftModelForCausalLM.from_pretrained(
        model_path,
        # float16 is the conventional compute dtype alongside 4-bit
        # quantization; the original inverted the condition and used
        # float32 when load_in_4bit was True, doubling compute memory.
        torch_dtype=torch.float16,
        load_in_4bit=load_in_4bit,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model.eval()  # inference mode: disables dropout etc.
    return model, tokenizer


model, tokenizer = _load_model_and_tokenizer()

# Streamlit input for user prompt
user_input = st.text_input(
    "Enter your legal query:",
    "What are the penalties for breaking a red light in India?",
)

if user_input:
    # Prepare the prompt in the chat-message format the template expects.
    messages = [{"role": "user", "content": user_input}]

    # Tokenize with the model's chat template and move the tensor to the
    # model's own device — device_map="auto" decides placement, so don't
    # hard-code "cuda"/"cpu" as the original did.
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    # Streamlit progress indicator
    with st.spinner("Generating response..."):
        # TextStreamer streams tokens to the server console only; the
        # decoded text below is what the user actually sees on the page.
        text_streamer = TextStreamer(tokenizer, skip_prompt=True)

        with torch.inference_mode():
            output = model.generate(
                input_ids=inputs,
                streamer=text_streamer,
                max_new_tokens=128,
                use_cache=True,
                temperature=1.5,
                min_p=0.1,
            )

    # BUG FIX: the original never rendered the answer in the UI — the
    # streamer prints to stdout and `output` was discarded. Decode only the
    # newly generated tokens (everything after the prompt) and display them.
    response = tokenizer.decode(
        output[0][inputs.shape[-1]:], skip_special_tokens=True
    )
    st.markdown(response)

    st.success("Generation Complete!")