ajeetkumar01 commited on
Commit
51391bc
·
verified ·
1 Parent(s): 9ace159

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -0
app.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
+ import torch
5
+ from huggingface_hub import login # For authentication
6
+
7
+ # Authenticate with Hugging Face
8
+ def authenticate_huggingface():
9
+ token = os.getenv("llama2_token") # Load token from environment variable
10
+ if token:
11
+ login(token) # This logs in using the Hugging Face token
12
+ else:
13
+ st.error("Hugging Face token not found. Please set the HF_TOKEN environment variable.")
14
+
15
+ # Load the Llama 2 model from Hugging Face
16
+ @st.cache_resource
17
+ def load_llama_model():
18
+ authenticate_huggingface() # Ensure authentication is done before loading
19
+ model_name = "meta-llama/Llama-2-7b-hf"
20
+ tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=True)
21
+ model = AutoModelForCausalLM.from_pretrained(model_name, use_auth_token=True)
22
+ return tokenizer, model
23
+
24
+ # Function to query the Llama 2 model
25
+ def query_llama_model(penal_code, tokenizer, model):
26
+ prompt = f"What is California Penal Code {penal_code}?"
27
+
28
+ # Tokenize the input prompt
29
+ inputs = tokenizer(prompt, return_tensors="pt")
30
+
31
+ # Generate output from the model
32
+ outputs = model.generate(**inputs, max_new_tokens=100)
33
+
34
+ # Decode the generated text
35
+ description = tokenizer.decode(outputs[0], skip_special_tokens=True)
36
+ return description
37
+
38
+ # Function to process CSV and update descriptions
39
+ def update_csv_with_descriptions(csv_file, tokenizer, model):
40
+ # Read the CSV file
41
+ df = pd.read_csv(csv_file)
42
+
43
+ # Dictionary to store penal codes and their descriptions
44
+ penal_code_dict = {}
45
+
46
+ # Iterate through each row in the CSV
47
+ for index, row in df.iterrows():
48
+ penal_code = row['Offense Number']
49
+
50
+ # Check if description is already present
51
+ if not row['Description']:
52
+ st.write(f"Querying description for {penal_code}...")
53
+ description = query_llama_model(penal_code, tokenizer, model)
54
+
55
+ # Update the dataframe with the description
56
+ df.at[index, 'Description'] = description
57
+
58
+ # Add to dictionary
59
+ penal_code_dict[penal_code] = description
60
+
61
+ # Save the updated CSV file
62
+ updated_file_path = 'updated_' + csv_file.name
63
+ df.to_csv(updated_file_path, index=False)
64
+
65
+ return penal_code_dict, updated_file_path
66
+
67
+ # Streamlit UI
68
+ def main():
69
+ st.title("Penal Code Description Extractor with Llama 2")
70
+
71
+ # Load the Llama 2 model and tokenizer
72
+ tokenizer, model = load_llama_model()
73
+
74
+ # Upload CSV file
75
+ uploaded_file = st.file_uploader("Upload a CSV file with Penal Codes", type=["csv"])
76
+
77
+ if uploaded_file is not None:
78
+ # Display uploaded file
79
+ st.write("Uploaded CSV File:")
80
+ df = pd.read_csv(uploaded_file)
81
+ st.dataframe(df)
82
+
83
+ # Process the file and update descriptions
84
+ if st.button("Get Penal Code Descriptions"):
85
+ penal_code_dict, updated_file_path = update_csv_with_descriptions(uploaded_file, tokenizer, model)
86
+
87
+ # Show dictionary output
88
+ st.write("Penal Code Descriptions:")
89
+ st.json(penal_code_dict)
90
+
91
+ # Provide a download link for the updated CSV
92
+ with open(updated_file_path, 'rb') as f:
93
+ st.download_button(
94
+ label="Download Updated CSV",
95
+ data=f,
96
+ file_name=updated_file_path,
97
+ mime='text/csv'
98
+ )
99
+
100
+ if __name__ == "__main__":
101
+ main()