AruniAnkur committed
Commit 699f7b6 · verified · 1 Parent(s): 4c4d5fa

Upload functionbloom.py

Files changed (1)
  1. functionbloom.py +388 -0
functionbloom.py ADDED
@@ -0,0 +1,388 @@
+ from typing import Optional, Dict
+ import streamlit as st
+ import requests
+ import json
+ import fitz  # PyMuPDF
+ from fpdf import FPDF
+ import os
+ import tempfile
+ from dotenv import load_dotenv
+ import torch
+ from transformers import DistilBertForSequenceClassification, DistilBertTokenizer
+ from torch.nn.functional import softmax
+ from doctr.models import ocr_predictor
+ from doctr.io import DocumentFile
+
+ load_dotenv()
+
+ # Load the fine-tuned DistilBERT classifier and its tokenizer
+ model = DistilBertForSequenceClassification.from_pretrained('./fine_tuned_distilbert')
+ tokenizer = DistilBertTokenizer.from_pretrained('./fine_tuned_distilbert')
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model.to(device)
+
+ # Label <-> index mapping for the six Bloom's taxonomy levels
+ mapping = {"Remembering": 0, "Understanding": 1, "Applying": 2, "Analyzing": 3, "Evaluating": 4, "Creating": 5}
+ reverse_mapping = {v: k for k, v in mapping.items()}
+
+ # docTR OCR model for reading scanned question papers
+ modelocr = ocr_predictor(det_arch='db_resnet50', reco_arch='crnn_vgg16_bn', pretrained=True)
+
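Since the checkpoint is loaded from a local directory, a quick sanity check can confirm that the classifier head matches the six-level mapping above; a minimal sketch using the standard transformers config attribute:

# Sketch: verify the checkpoint was saved with six output labels.
assert model.config.num_labels == len(mapping), \
    f"Expected {len(mapping)} Bloom's levels, got {model.config.num_labels}"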
+ def save_uploaded_file(uploaded_file):
+     if uploaded_file is not None:
+         file_extension = uploaded_file.name.split('.')[-1].lower()
+         temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=f'.{file_extension}')
+         temp_file.write(uploaded_file.getvalue())
+         temp_file.close()
+         return temp_file.name
+     return None
+
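A minimal usage sketch for this helper inside a Streamlit page; the widget label and accepted types are illustrative:

# Sketch: wire the helper to a Streamlit uploader.
uploaded = st.file_uploader("Upload a question paper", type=["pdf", "jpg", "jpeg", "png"])
if uploaded is not None:
    saved_path = save_uploaded_file(uploaded)  # e.g. a temp file ending in .pdf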
+ # Previous functions from the Question Generator
+ def get_pdf_path(pdf_source=None, uploaded_file=None):
+     try:
+         # If a file is uploaded locally
+         if uploaded_file is not None:
+             # Create a temporary directory to hold the uploaded PDF
+             temp_dir = tempfile.mkdtemp()
+             pdf_path = os.path.join(temp_dir, uploaded_file.name)
+
+             # Save the uploaded file
+             with open(pdf_path, "wb") as pdf_file:
+                 pdf_file.write(uploaded_file.getvalue())
+             return pdf_path
+
+         # If a URL is provided, download the PDF
+         if pdf_source:
+             response = requests.get(pdf_source, timeout=30)
+             response.raise_for_status()
+
+             # Create a temporary file
+             temp_dir = tempfile.mkdtemp()
+             pdf_path = os.path.join(temp_dir, "downloaded.pdf")
+
+             with open(pdf_path, "wb") as pdf_file:
+                 pdf_file.write(response.content)
+             return pdf_path
+
+         # If no source is provided
+         st.error("No PDF source provided.")
+         return None
+     except Exception as e:
+         st.error(f"Error getting PDF: {e}")
+         return None
+
+
+ def extract_text_pymupdf(pdf_path):
+     try:
+         doc = fitz.open(pdf_path)
+         pages_content = []
+         for page_num in range(len(doc)):
+             page = doc[page_num]
+             pages_content.append(page.get_text())
+         doc.close()
+         return " ".join(pages_content)  # Join all pages into one large context string
+     except Exception as e:
+         st.error(f"Error extracting text from PDF: {e}")
+         return ""
+
+
+ def get_bloom_taxonomy_scores(question: str) -> Dict[str, float]:
+     # Fallback scores in case the local model prediction fails
+     default_scores = {
+         "Remembering": 0.2,
+         "Understanding": 0.2,
+         "Applying": 0.15,
+         "Analyzing": 0.15,
+         "Evaluating": 0.15,
+         "Creating": 0.15
+     }
+
+     try:
+         scores = predict_with_loaded_model(question)
+         for key, value in scores.items():
+             if not (0 <= value <= 1):
+                 st.warning(f"Invalid score value for {key}. Using default scores.")
+                 return default_scores
+         return scores
+
+     except Exception as e:
+         st.warning(f"Unexpected error: {e}. Using default scores.")
+         return default_scores
+
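A hedged usage sketch; the scores are the classifier's softmax probabilities over the six levels, and the values shown are illustrative, not real model output:

# Sketch: score one question and pick its dominant level.
scores = get_bloom_taxonomy_scores("Explain the difference between TCP and UDP.")
# e.g. {'Remembering': 0.05, 'Understanding': 0.72, 'Applying': 0.10, ...}
top_level = max(scores, key=scores.get)  # 'Understanding' in this illustrative case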
+
+ def generate_ai_response(api_key, assistant_context, user_query, role_description, response_instructions, bloom_taxonomy_weights, num_questions, question_length, include_numericals):
+     try:
+         url = f"https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-latest:generateContent?key={api_key}"
+
+         # Define length guidelines
+         length_guidelines = {
+             "Short": "Keep questions concise, around 10-15 words each.",
+             "Medium": "Create moderately detailed questions, around 20-25 words each.",
+             "Long": "Generate detailed, comprehensive questions, around 30-40 words each that may include multiple parts."
+         }
+
+         prompt = f"""
+         You are a highly knowledgeable assistant. Your task is to assist the user with the following context from an academic paper.
+
+         **Role**: {role_description}
+
+         **Context**: {assistant_context}
+
+         **Instructions**: {response_instructions}
+         Question Length Requirement: {length_guidelines[question_length]}
+
+         **Bloom's Taxonomy Weights**:
+         Knowledge: {bloom_taxonomy_weights['Knowledge']}%
+         Comprehension: {bloom_taxonomy_weights['Comprehension']}%
+         Application: {bloom_taxonomy_weights['Application']}%
+         Analysis: {bloom_taxonomy_weights['Analysis']}%
+         Synthesis: {bloom_taxonomy_weights['Synthesis']}%
+         Evaluation: {bloom_taxonomy_weights['Evaluation']}%
+
+         **Query**: {user_query}
+
+         **Number of Questions**: {num_questions}
+
+         **Include Numericals**: {include_numericals}
+         """
+
+         payload = {
+             "contents": [
+                 {
+                     "parts": [
+                         {"text": prompt}
+                     ]
+                 }
+             ]
+         }
+         headers = {"Content-Type": "application/json"}
+
+         response = requests.post(url, headers=headers, data=json.dumps(payload), timeout=60)
+         response.raise_for_status()
+
+         result = response.json()
+         questions = result.get("candidates", [{}])[0].get("content", {}).get("parts", [{}])[0].get("text", "")
+         questions_list = [question.strip() for question in questions.split("\n") if question.strip()]
+
+         # Get Bloom's taxonomy scores for each question, with a progress bar
+         questions_with_scores = []
+         progress_bar = st.progress(0)
+         for idx, question in enumerate(questions_list):
+             scores = get_bloom_taxonomy_scores(question)
+             if scores:  # Only keep questions that received valid scores
+                 questions_with_scores.append((question, scores))
+             progress_bar.progress((idx + 1) / len(questions_list))
+
+         if not questions_with_scores:
+             st.warning("Could not get Bloom's Taxonomy scores for any questions. Using default scores.")
+             # Fall back to default scores if none were obtained
+             questions_with_scores = [(q, get_bloom_taxonomy_scores("")) for q in questions_list]
+
+         # Cache the scores in session state for later display
+         st.session_state.question_scores = {q: s for q, s in questions_with_scores}
+
+         # Return just the question strings
+         return [q for q, _ in questions_with_scores]
+     except requests.RequestException as e:
+         st.error(f"API request error: {e}")
+         return []
+     except Exception as e:
+         st.error(f"Error generating questions: {e}")
+         return []
+
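The chained .get(...) accessors above assume the generateContent response shape; a self-contained sketch of the structure they walk, with illustrative values:

# Sketch: the JSON shape the parser above expects.
example = {"candidates": [{"content": {"parts": [{"text": "Q1\nQ2"}]}}]}
text = example.get("candidates", [{}])[0].get("content", {}).get("parts", [{}])[0].get("text", "")
assert text.split("\n") == ["Q1", "Q2"]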
+ def normalize_bloom_weights(bloom_weights):
+     total = sum(bloom_weights.values())
+     if total and total != 100:  # guard against an all-zero weight set
+         normalization_factor = 100 / total
+         # Normalize each weight by multiplying it by the normalization factor
+         bloom_weights = {key: round(value * normalization_factor, 2) for key, value in bloom_weights.items()}
+     return bloom_weights
+
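For example, a weight set summing to 50 is scaled by a factor of 2 so it sums to 100:

# Sketch: normalization arithmetic on an illustrative weight set.
weights = {"Knowledge": 10, "Comprehension": 10, "Application": 10,
           "Analysis": 10, "Synthesis": 5, "Evaluation": 5}
normalize_bloom_weights(weights)
# {'Knowledge': 20.0, 'Comprehension': 20.0, 'Application': 20.0,
#  'Analysis': 20.0, 'Synthesis': 10.0, 'Evaluation': 10.0}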
+ def generate_pdf(questions, filename="questions.pdf"):
+     try:
+         pdf = FPDF()
+         pdf.set_auto_page_break(auto=True, margin=15)
+         pdf.add_page()
+
+         # Set font
+         pdf.set_font("Arial", size=12)
+
+         # Add a title or heading
+         pdf.cell(200, 10, txt="Generated Questions", ln=True, align="C")
+
+         # Add space between title and questions
+         pdf.ln(10)
+
+         # Loop through questions and add them to the PDF
+         for i, question in enumerate(questions, 1):
+             # Use multi_cell to wrap text that is too long for one line
+             pdf.multi_cell(0, 10, f"Q{i}: {question}")
+
+         # Save the generated PDF to the file
+         pdf.output(filename)
+         return filename
+     except Exception as e:
+         st.error(f"Error generating PDF: {e}")
+         return None
+
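In the app, the returned path can be wired to a download widget; a sketch where the label and sample questions are illustrative:

# Sketch: offer the generated PDF for download in Streamlit.
pdf_path = generate_pdf(["Define overfitting.", "Explain gradient descent."])
if pdf_path:
    with open(pdf_path, "rb") as f:
        st.download_button("Download questions", f, file_name="questions.pdf")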
+ def process_pdf_and_generate_questions(pdf_source, uploaded_file, api_key, role_description, response_instructions, bloom_taxonomy_weights, num_questions, question_length, include_numericals):
+     try:
+         pdf_path = get_pdf_path(pdf_source, uploaded_file)
+         if not pdf_path:
+             return []
+
+         # Extract text
+         pdf_text = extract_text_pymupdf(pdf_path)
+         if not pdf_text:
+             return []
+
+         # Generate questions
+         assistant_context = pdf_text
+         user_query = "Generate questions based on the above context."
+         normalized_bloom_weights = normalize_bloom_weights(bloom_taxonomy_weights)
+         questions = generate_ai_response(
+             api_key,
+             assistant_context,
+             user_query,
+             role_description,
+             response_instructions,
+             normalized_bloom_weights,
+             num_questions,
+             question_length,
+             include_numericals
+         )
+
+         # Clean up the temporary PDF file and its directory
+         try:
+             os.remove(pdf_path)
+             os.rmdir(os.path.dirname(pdf_path))
+         except Exception as e:
+             st.warning(f"Could not delete temporary PDF file: {e}")
+
+         return questions
+     except Exception as e:
+         st.error(f"Error processing PDF and generating questions: {e}")
+         return []
+
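An end-to-end call might look like the following sketch; every argument value here is illustrative:

# Sketch: full pipeline from PDF source to scored questions.
questions = process_pdf_and_generate_questions(
    pdf_source="https://example.com/lecture_notes.pdf",  # or None with an uploaded_file
    uploaded_file=None,
    api_key=os.getenv("GEMINI_API_KEY"),
    role_description="You are an exam setter for an undergraduate course.",
    response_instructions="Put each question on its own line.",
    bloom_taxonomy_weights={"Knowledge": 20, "Comprehension": 20, "Application": 15,
                            "Analysis": 15, "Synthesis": 15, "Evaluation": 15},
    num_questions=10,
    question_length="Medium",
    include_numericals=True,
)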
+ def get_bloom_taxonomy_details(question_scores: Optional[Dict[str, float]] = None) -> str:
+     """
+     Generate a detailed explanation of Bloom's Taxonomy scores.
+     Handles missing or invalid scores gracefully.
+     """
+     try:
+         if question_scores is None or not isinstance(question_scores, dict):
+             return "Bloom's Taxonomy scores not available"
+
+         # Validate scores
+         valid_categories = {"Remembering", "Understanding", "Applying",
+                             "Analyzing", "Evaluating", "Creating"}
+
+         if not all(isinstance(score, (int, float)) for score in question_scores.values()):
+             return "Invalid score values detected"
+
+         if not all(category in valid_categories for category in question_scores.keys()):
+             return "Invalid score categories detected"
+
+         details_text = "Bloom's Taxonomy Analysis:\n\n"
+
+         try:
+             # Sort scores by value in descending order
+             sorted_scores = sorted(question_scores.items(), key=lambda x: x[1], reverse=True)
+
+             # Format each score as a percentage
+             for category, score in sorted_scores:
+                 percentage = min(max(score * 100, 0), 100)  # Clamp the percentage to 0-100
+                 details_text += f"{category}: {percentage:.1f}%\n"
+
+             # Add the predicted level
+             predicted_level = max(question_scores.items(), key=lambda x: x[1])[0]
+             details_text += f"\nPredicted Level: {predicted_level}"
+
+             return details_text.strip()
+
+         except Exception as e:
+             return f"Error processing scores: {str(e)}"
+
+     except Exception as e:
+         return f"Error generating taxonomy details: {str(e)}"
+
+
+ def predict_with_loaded_model(text):
+     inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512)
+     input_ids = inputs['input_ids'].to(device)
+     attention_mask = inputs['attention_mask'].to(device)  # pass the mask so padded tokens are ignored
+     model.eval()
+     with torch.no_grad():
+         outputs = model(input_ids, attention_mask=attention_mask)
+         logits = outputs.logits
+     probabilities = softmax(logits, dim=-1)
+     probabilities = probabilities.squeeze().cpu().numpy()
+     # Convert to float and format to 3 decimal places
+     class_probabilities = {reverse_mapping[i]: float(f"{prob:.3f}") for i, prob in enumerate(probabilities)}
+     return class_probabilities
+
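A usage sketch for the classifier; the probabilities shown are illustrative, not real model output:

# Sketch: classify one question into the six Bloom's levels.
probs = predict_with_loaded_model("Design an experiment to compare two sorting algorithms.")
# e.g. {'Remembering': 0.02, 'Understanding': 0.05, 'Applying': 0.10,
#       'Analyzing': 0.15, 'Evaluating': 0.18, 'Creating': 0.50}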
+ def process_document(input_path):
+     if input_path.lower().endswith(".pdf"):
+         doc = DocumentFile.from_pdf(input_path)
+     elif input_path.lower().endswith((".jpg", ".jpeg", ".png", ".bmp", ".tiff")):
+         doc = DocumentFile.from_images(input_path)
+     else:
+         raise ValueError("Unsupported file type. Please provide a PDF or an image file.")
+     result = modelocr(doc)
+
+     def calculate_average_confidence(result):
+         # Average word-level recognition confidence across the whole document
+         total_confidence = 0
+         word_count = 0
+         for page in result.pages:
+             for block in page.blocks:
+                 for line in block.lines:
+                     for word in line.words:
+                         total_confidence += word.confidence
+                         word_count += 1
+         return total_confidence / word_count if word_count > 0 else 0
+
+     average_confidence = calculate_average_confidence(result)
+     string_result = result.render()
+     return {'Avg_Confidence': average_confidence, 'String': string_result.split('\n')}
+
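A usage sketch; the input filename is hypothetical:

# Sketch: OCR a scanned paper and inspect the recognition confidence.
ocr_out = process_document("scanned_paper.pdf")
print(f"Average OCR confidence: {ocr_out['Avg_Confidence']:.2f}")
for fragment in ocr_out['String'][:5]:
    print(fragment)  # raw OCR line fragments, merged into questions later by Gemini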
+ def sendtogemini(inputpath, question):
+     if inputpath and inputpath.lower().endswith((".pdf", ".jpg", ".jpeg", ".png")):
+         qw = process_document(inputpath)
+     elif question:
+         qw = {'String': [question]}
+     else:
+         raise ValueError("No input provided. Pass a PDF/image path or a question string.")
+     questionset = str(qw['String'])
+     # Send this prompt to Gemini:
+     questionset += """You are given a list of text fragments containing question fragments extracted by an OCR model. Your task is to:
+     # Only merge the question fragments into complete and coherent questions. Don't answer them.
+     # Separate each question; start a new question with @ to make them easily distinguishable for further processing."""
+     url = f"https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-latest:generateContent?key={os.getenv('GEMINI_API_KEY')}"
+
+     payload = {
+         "contents": [
+             {
+                 "parts": [
+                     {"text": questionset}
+                 ]
+             }
+         ]
+     }
+     headers = {"Content-Type": "application/json"}
+
+     response = requests.post(url, headers=headers, data=json.dumps(payload), timeout=60)
+     response.raise_for_status()
+     result = response.json()
+     res1 = result.get("candidates", [{}])[0].get("content", {}).get("parts", [{}])[0].get("text", "")
+
+     # Rebuild the question list from the @-prefixed lines Gemini returns
+     questions = []
+     for i in res1.split('\n'):
+         i = i.strip()
+         if i and i[0] == '@':
+             i = i[1:].strip().lower()
+             if i and i[0] == 'q':  # drop a leading "Q"-style numbering marker if present
+                 questions.append(i[1:].strip())
+             else:
+                 questions.append(i)
+
+     # Score each reconstructed question with the local classifier
+     data = []
+     for q in questions:
+         d = {}
+         d['question'] = q
+         d['score'] = predict_with_loaded_model(q)
+         data.append(d)
+     return data
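Finally, a sketch of the whole OCR-to-classification flow; the path is illustrative:

# Sketch: OCR a scanned paper, merge fragments via Gemini, classify each question.
results = sendtogemini("scanned_paper.pdf", question=None)
for item in results:
    level = max(item['score'], key=item['score'].get)
    print(f"{item['question']} -> {level}")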