adnanaman committed on
Commit b2d81b0 · verified · 1 Parent(s): 2163c58

Update app.py

Files changed (1)
  1. app.py +215 -209
app.py CHANGED
@@ -1,209 +1,215 @@
- # streamlit_tapex_app.py
-
- import streamlit as st
- import pandas as pd
- import torch
- from transformers import TapexTokenizer, BartForConditionalGeneration
- import xml.etree.ElementTree as ET
- from io import StringIO
- import logging
- from datetime import datetime
- import time
-
- # Configure logging
- logging.basicConfig(
-     level=logging.INFO,
-     format='%(asctime)s - %(levelname)s - %(message)s'
- )
- logger = logging.getLogger(__name__)
-
- @st.cache_resource
- def load_model():
-     """
-     Load and cache the TAPEX model and tokenizer using Streamlit's caching
-     """
-     try:
-         tokenizer = TapexTokenizer.from_pretrained(
-             "microsoft/tapex-large-finetuned-wtq",
-             model_max_length=1024
-         )
-         model = BartForConditionalGeneration.from_pretrained(
-             "microsoft/tapex-large-finetuned-wtq"
-         )
-         device = 'cuda' if torch.cuda.is_available() else 'cpu'
-         model = model.to(device)
-         model.eval()
-         return tokenizer, model
-     except Exception as e:
-         st.error(f"Error loading model: {str(e)}")
-         return None, None
-
- def parse_xml_to_dataframe(xml_string: str):
-     """
-     Parse XML string to DataFrame with error handling
-     """
-     try:
-         tree = ET.parse(StringIO(xml_string))
-         root = tree.getroot()
-
-         data = []
-         columns = set()
-
-         # First pass: collect all possible columns
-         for record in root.findall('.//record'):
-             columns.update(elem.tag for elem in record)
-
-         # Second pass: create data rows
-         for record in root.findall('.//record'):
-             row_data = {col: None for col in columns}
-             for elem in record:
-                 row_data[elem.tag] = elem.text
-             data.append(row_data)
-
-         df = pd.DataFrame(data)
-
-         # Convert numeric columns (automatically detect)
-         for col in df.columns:
-             try:
-                 df[col] = pd.to_numeric(df[col])
-             except:
-                 continue
-
-         return df, None
-     except Exception as e:
-         return None, f"Error parsing XML: {str(e)}"
-
- def process_query(tokenizer, model, df, query: str):
-     """
-     Process a single query using the TAPEX model
-     """
-     try:
-         start_time = time.time()
-
-         # Handle direct DataFrame operations for common queries
-         query_lower = query.lower()
-         if "highest" in query_lower or "maximum" in query_lower:
-             for col in df.select_dtypes(include=['number']).columns:
-                 if col.lower() in query_lower:
-                     return df.loc[df[col].idxmax()].to_dict()
-         elif "average" in query_lower or "mean" in query_lower:
-             for col in df.select_dtypes(include=['number']).columns:
-                 if col.lower() in query_lower:
-                     return f"Average {col}: {df[col].mean():.2f}"
-         elif "total" in query_lower or "sum" in query_lower:
-             for col in df.select_dtypes(include=['number']).columns:
-                 if col.lower() in query_lower:
-                     return f"Total {col}: {df[col].sum():.2f}"
-
-         # Use TAPEX for more complex queries
-         with torch.no_grad():
-             encoding = tokenizer(
-                 table=df.astype(str),
-                 query=query,
-                 return_tensors="pt",
-                 padding=True,
-                 truncation=True
-             )
-             outputs = model.generate(**encoding)
-             answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
-
-         processing_time = time.time() - start_time
-         return f"Answer: {answer} (Processing time: {processing_time:.2f}s)"
-
-     except Exception as e:
-         return f"Error processing query: {str(e)}"
-
- def main():
-     st.title("XML Data Query System")
-     st.write("Upload your XML data and ask questions about it!")
-
-     # Initialize session state for XML input if not exists
-     if 'xml_input' not in st.session_state:
-         st.session_state.xml_input = ""
-
-     # Load model
-     with st.spinner("Loading TAPEX model... (this may take a few moments)"):
-         tokenizer, model = load_model()
-         if tokenizer is None or model is None:
-             st.error("Failed to load the model. Please refresh the page.")
-             return
-
-     # XML Input
-     xml_input = st.text_area(
-         "Enter your XML data here:",
-         value=st.session_state.xml_input,
-         height=200,
-         help="Paste your XML data here. Make sure it's properly formatted."
-     )
-
-     # Sample XML button
-     if st.button("Load Sample XML"):
-         st.session_state.xml_input = """<?xml version="1.0" encoding="UTF-8"?>
- <data>
-     <records>
-         <record>
-             <company>Apple</company>
-             <revenue>365.7</revenue>
-             <employees>147000</employees>
-             <year>2021</year>
-         </record>
-         <record>
-             <company>Microsoft</company>
-             <revenue>168.1</revenue>
-             <employees>181000</employees>
-             <year>2021</year>
-         </record>
-         <record>
-             <company>Amazon</company>
-             <revenue>386.1</revenue>
-             <employees>1608000</employees>
-             <year>2021</year>
-         </record>
-     </records>
- </data>"""
-         st.rerun()
-
-     if xml_input:
-         df, error = parse_xml_to_dataframe(xml_input)
-         if error:
-             st.error(error)
-         else:
-             st.success("XML parsed successfully!")
-
-             # Display DataFrame
-             st.subheader("Parsed Data:")
-             st.dataframe(df)
-
-             # Query input
-             query = st.text_input(
-                 "Enter your question about the data:",
-                 help="Example: 'Which company has the highest revenue?'"
-             )
-
-             # Process query
-             if query:
-                 with st.spinner("Processing query..."):
-                     result = process_query(tokenizer, model, df, query)
-                     st.write(result)
-
-             # Sample queries
-             st.subheader("Sample Questions:")
-             sample_queries = [
-                 "Which company has the highest revenue?",
-                 "What is the average revenue of all companies?",
-                 "How many employees does Microsoft have?",
-                 "Which company has the most employees?",
-                 "What is the total revenue of all companies?"
-             ]
-
-             # Create columns for sample query buttons
-             cols = st.columns(len(sample_queries))
-             for idx, (col, sample_query) in enumerate(zip(cols, sample_queries)):
-                 with col:
-                     if st.button(f"Query {idx + 1}", help=sample_query):
-                         with st.spinner("Processing query..."):
-                             result = process_query(tokenizer, model, df, sample_query)
-                             st.write(result)
-
- if __name__ == "__main__":
-     main()
+
+ import streamlit as st
+ import pandas as pd
+ import torch
+ from transformers import TapexTokenizer, BartForConditionalGeneration
+ import xml.etree.ElementTree as ET
+ from io import StringIO
+ import logging
+ from datetime import datetime
+ import time
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(levelname)s - %(message)s'
+ )
+ logger = logging.getLogger(__name__)
+
+ @st.cache_resource
+ def load_model():
+     """
+     Load and cache the TAPEX model and tokenizer using Streamlit's caching
+     """
+     try:
+         tokenizer = TapexTokenizer.from_pretrained(
+             "microsoft/tapex-large-finetuned-wtq",
+             model_max_length=1024
+         )
+         model = BartForConditionalGeneration.from_pretrained(
+             "microsoft/tapex-large-finetuned-wtq"
+         )
+         device = 'cuda' if torch.cuda.is_available() else 'cpu'
+         model = model.to(device)
+         model.eval()
+         return tokenizer, model
+     except Exception as e:
+         st.error(f"Error loading model: {str(e)}")
+         return None, None
+
+ def parse_xml_to_dataframe(xml_string: str):
+     """
+     Parse XML string to DataFrame with error handling
+     """
+     try:
+         tree = ET.parse(StringIO(xml_string))
+         root = tree.getroot()
+
+         data = []
+         columns = set()
+
+         # First pass: collect all possible columns
+         for record in root.findall('.//record'):
+             columns.update(elem.tag for elem in record)
+
+         # Second pass: create data rows
+         for record in root.findall('.//record'):
+             row_data = {col: None for col in columns}
+             for elem in record:
+                 row_data[elem.tag] = elem.text
+             data.append(row_data)
+
+         df = pd.DataFrame(data)
+
+         # Convert numeric columns (automatically detect)
+         for col in df.columns:
+             try:
+                 df[col] = pd.to_numeric(df[col])
+             except:
+                 continue
+
+         return df, None
+     except Exception as e:
+         return None, f"Error parsing XML: {str(e)}"
+
+ def process_query(tokenizer, model, df, query: str):
+     """
+     Process a single query using the TAPEX model
+     """
+     try:
+         start_time = time.time()
+
+         # Handle direct DataFrame operations for common queries
+         query_lower = query.lower()
+         if "highest" in query_lower or "maximum" in query_lower:
+             for col in df.select_dtypes(include=['number']).columns:
+                 if col.lower() in query_lower:
+                     return df.loc[df[col].idxmax()].to_dict()
+         elif "average" in query_lower or "mean" in query_lower:
+             for col in df.select_dtypes(include=['number']).columns:
+                 if col.lower() in query_lower:
+                     return f"Average {col}: {df[col].mean():.2f}"
+         elif "total" in query_lower or "sum" in query_lower:
+             for col in df.select_dtypes(include=['number']).columns:
+                 if col.lower() in query_lower:
+                     return f"Total {col}: {df[col].sum():.2f}"
+
+         # Use TAPEX for more complex queries
+         with torch.no_grad():
+             encoding = tokenizer(
+                 table=df.astype(str),
+                 query=query,
+                 return_tensors="pt",
+                 padding=True,
+                 truncation=True
+             )
+             outputs = model.generate(**encoding)
+             answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
+
+         processing_time = time.time() - start_time
+         return f"Answer: {answer} (Processing time: {processing_time:.2f}s)"
+
+     except Exception as e:
+         return f"Error processing query: {str(e)}"
+
+ def main():
+     st.title("XML Data Query System")
+     st.write("Upload your XML data and ask questions about it!")
+
+     # Initialize session state for XML input and query if not exists
+     if 'xml_input' not in st.session_state:
+         st.session_state.xml_input = ""
+     if 'current_query' not in st.session_state:
+         st.session_state.current_query = ""
+
+     # Load model
+     with st.spinner("Loading TAPEX model... (this may take a few moments)"):
+         tokenizer, model = load_model()
+         if tokenizer is None or model is None:
+             st.error("Failed to load the model. Please refresh the page.")
+             return
+
+     # XML Input
+     xml_input = st.text_area(
+         "Enter your XML data here:",
+         value=st.session_state.xml_input,
+         height=200,
+         help="Paste your XML data here. Make sure it's properly formatted."
+     )
+
+     # Sample XML button
+     if st.button("Load Sample XML"):
+         st.session_state.xml_input = """<?xml version="1.0" encoding="UTF-8"?>
+ <data>
+     <records>
+         <record>
+             <company>Apple</company>
+             <revenue>365.7</revenue>
+             <employees>147000</employees>
+             <year>2021</year>
+         </record>
+         <record>
+             <company>Microsoft</company>
+             <revenue>168.1</revenue>
+             <employees>181000</employees>
+             <year>2021</year>
+         </record>
+         <record>
+             <company>Amazon</company>
+             <revenue>386.1</revenue>
+             <employees>1608000</employees>
+             <year>2021</year>
+         </record>
+     </records>
+ </data>"""
+         st.rerun()
+
+     if xml_input:
+         df, error = parse_xml_to_dataframe(xml_input)
+         if error:
+             st.error(error)
+         else:
+             st.success("XML parsed successfully!")
+
+             # Display DataFrame
+             st.subheader("Parsed Data:")
+             st.dataframe(df)
+
+             # Query input
+             query = st.text_input(
+                 "Enter your question about the data:",
+                 value=st.session_state.current_query,
+                 help="Example: 'Which company has the highest revenue?'"
+             )
+
+             # Process query
+             if query:
+                 with st.spinner("Processing query..."):
+                     result = process_query(tokenizer, model, df, query)
+                     st.write(result)
+
+             # Sample queries
+             st.subheader("Sample Questions (Click to use):")
+             sample_queries = [
+                 "Which company has the highest revenue?",
+                 "What is the average revenue of all companies?",
+                 "How many employees does Microsoft have?",
+                 "Which company has the most employees?",
+                 "What is the total revenue of all companies?"
+             ]
+
+             # Create columns for sample query buttons
+             cols = st.columns(len(sample_queries))
+             for idx, (col, sample_query) in enumerate(zip(cols, sample_queries)):
+                 with col:
+                     if st.button(f"Query {idx + 1}", help=sample_query, key=f"query_btn_{idx}"):
+                         st.session_state.current_query = sample_query
+                         st.rerun()
+
+             # Display the sample queries as text for reference
+             with st.expander("View all sample questions"):
+                 for idx, query in enumerate(sample_queries, 1):
+                     st.write(f"{idx}. {query}")
+
+ if __name__ == "__main__":
+     main()
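
The main behavioral change in this commit is that the sample-question buttons no longer run the model inline: each button writes its question into st.session_state.current_query and calls st.rerun(), and the shared st.text_input reads its value from that key on the next run, so every question flows through the single process_query path. Below is a minimal, illustrative sketch of that pattern in isolation (not part of the commit); answer() is a hypothetical stand-in for the TAPEX call and the file name is assumed.

# session_state_query_sketch.py - illustrative only; answer() stands in for process_query()
import streamlit as st

def answer(question: str) -> str:
    # Placeholder for process_query(tokenizer, model, df, question)
    return f"(model answer for: {question})"

# Persist the selected question across reruns
if 'current_query' not in st.session_state:
    st.session_state.current_query = ""

# The text box is pre-filled from session state, so a button click below
# effectively "types" a question into it on the next rerun.
query = st.text_input("Ask a question:", value=st.session_state.current_query)

if query:
    st.write(answer(query))

# Buttons only record the chosen question and trigger a rerun; they never call the model.
samples = ["Which company has the highest revenue?",
           "What is the total revenue of all companies?"]
for idx, sample in enumerate(samples):
    if st.button(f"Query {idx + 1}", help=sample, key=f"query_btn_{idx}"):
        st.session_state.current_query = sample
        st.rerun()

Running this with streamlit run session_state_query_sketch.py shows the round trip: clicking a button populates the text box and the answer updates on the following rerun, mirroring how the updated app.py routes both typed and sample questions through one code path.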