Saiteja Solleti committed on
Commit 95bfa0d · 1 Parent(s): 34b540a

adding load dataset func

Files changed (3)
  1. app.py +5 -7
  2. loaddataset.py +102 -0
  3. logger.py +5 -0
app.py CHANGED
@@ -1,22 +1,20 @@
 import gradio as gr
 import os
-import pandas as pd
 
-from typing import Dict, List, Optional
+from loaddataset import ExtractRagBenchData
 from model import generate_response
-from datasets import load_dataset
 from huggingface_hub import login
 from huggingface_hub import whoami
 from huggingface_hub import dataset_info
 
 
-DATASET_CONFIGS = [
-    'covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa', 'finqa', 'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa', 'tatqa', 'techqa'
-]
-
 hf_token = os.getenv("HF_TOKEN")
 login(hf_token)
 
+rag_extracted_data = ExtractRagBenchData()
+
+rag_extracted_data.head(5)
+
 def chatbot(prompt):
     return whoami()
 
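Note on the new top-level code in app.py: DataFrame.head(5) returns a DataFrame rather than printing it, so the bare rag_extracted_data.head(5) statement has no visible effect when the script runs. A minimal sketch of how the preview could be surfaced instead, reusing the commit's own modules (hypothetical follow-up, not part of this commit):

    # Hypothetical sketch, not in this commit: log the preview instead of
    # discarding the DataFrame that head() returns.
    from loaddataset import ExtractRagBenchData
    from logger import logger

    rag_extracted_data = ExtractRagBenchData()  # downloads all 12 RAGBench configs
    logger.info("RAGBench sample:\n%s", rag_extracted_data.head(5).to_string())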
loaddataset.py ADDED
@@ -0,0 +1,102 @@
+import pandas as pd
+from datasets import load_dataset
+from logger import logger
+from typing import Dict, List
+
+
+DATASET_CONFIGS = [
+    'covidqa', 'cuad', 'delucionqa', 'emanual', 'expertqa', 'finqa',
+    'hagrid', 'hotpotqa', 'msmarco', 'pubmedqa', 'tatqa', 'techqa'
+]
+
+
+def load_rag_bench_dataset(configs: List[str]) -> Dict[str, dict]:
+    """Load the RAGBench dataset for the given configurations.
+
+    Args:
+        configs: List of dataset configurations to load.
+
+    Returns:
+        A dictionary where keys are config names and values are the loaded datasets.
+    """
+    ragbench = {}
+    for config in configs:
+        try:
+            ragbench[config] = load_dataset("rungalileo/ragbench", config)
+            logger.info(f"Successfully loaded dataset for config: {config}")
+        except Exception as e:
+            logger.error(f"Failed to load dataset for config {config}: {e}")
+    return ragbench
+
+
+def ExtractData(ragbench: Dict[str, dict], split: str = "train") -> pd.DataFrame:
+    """Extract data from the RAGBench dataset into a Pandas DataFrame.
+
+    Args:
+        ragbench: Dictionary containing loaded datasets.
+        split: Dataset split to extract (e.g., "train", "test", "validation").
+
+    Returns:
+        A Pandas DataFrame containing the extracted data.
+    """
+    # Initialize a dictionary to store extracted data
+    data = {
+        "question": [],
+        "documents": [],
+        "gpt3_context_relevance": [],
+        "gpt35_utilization": [],
+        "gpt3_adherence": [],
+        "id": [],
+        "dataset_name": [],
+        "relevance_score": [],
+        "utilization_score": [],
+        "completeness_score": [],
+        "adherence_score": []
+    }
+
+    for datasetname, dataset in ragbench.items():
+        try:
+            # Ensure the split exists in the dataset
+            if split not in dataset:
+                logger.warning(f"Split '{split}' not found in dataset {datasetname}. Skipping.")
+                continue
+
+            # Extract data from the specified split
+            split_data = dataset[split]
+
+            # Check that every column read below exists, including the four score columns
+            required_columns = ["question", "documents", "gpt3_context_relevance",
+                                "gpt35_utilization", "gpt3_adherence", "id", "dataset_name",
+                                "relevance_score", "utilization_score",
+                                "completeness_score", "adherence_score"]
+            missing_columns = [col for col in required_columns if col not in split_data.column_names]
+            if missing_columns:
+                logger.warning(f"Missing columns {missing_columns} in dataset {datasetname}. Skipping.")
+                continue
+
+            # Append data to lists
+            data["question"].extend(split_data["question"])
+            data["documents"].extend(split_data["documents"])
+            data["gpt3_context_relevance"].extend(split_data["gpt3_context_relevance"])
+            data["gpt35_utilization"].extend(split_data["gpt35_utilization"])
+            data["gpt3_adherence"].extend(split_data["gpt3_adherence"])
+            data["id"].extend(split_data["id"])
+            data["dataset_name"].extend(split_data["dataset_name"])
+            data["relevance_score"].extend(split_data["relevance_score"])
+            data["utilization_score"].extend(split_data["utilization_score"])
+            data["completeness_score"].extend(split_data["completeness_score"])
+            data["adherence_score"].extend(split_data["adherence_score"])
+
+            logger.info(f"Successfully extracted data from {datasetname} ({split} split).")
+        except Exception as e:
+            logger.error(f"Error extracting data from {datasetname} ({split} split): {e}")
+
+    # Convert the dictionary to a Pandas DataFrame
+    df = pd.DataFrame(data)
+    return df
+
+
+def ExtractRagBenchData():
+    ragbench = load_rag_bench_dataset(DATASET_CONFIGS)
+    rag_extracted_data = ExtractData(ragbench, split="train")
+    return rag_extracted_data
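A quick way to exercise loaddataset.py in isolation, sketched under the assumption of network access to the Hub; loading a single config keeps the download small, and the 'covidqa'/'test' choices are illustrative:

    # Sketch: run the two functions on one config and a non-default split.
    from loaddataset import load_rag_bench_dataset, ExtractData

    ragbench = load_rag_bench_dataset(['covidqa'])  # dict: config name -> DatasetDict
    test_df = ExtractData(ragbench, split="test")   # DataFrame with the 11 extracted columns
    print(test_df.shape)
    print(test_df[["id", "dataset_name", "relevance_score"]].head())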
logger.py ADDED
@@ -0,0 +1,5 @@
+import logging
+
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
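Since logging.basicConfig runs when logger.py is first imported, the first importer configures the root logger for the whole app, and any later basicConfig calls are no-ops. A minimal usage sketch, mirroring how loaddataset.py consumes the shared logger:

    # Sketch: every module that imports `logger` shares the same configuration.
    from logger import logger

    logger.info("starting dataset load")        # emitted: INFO meets the INFO threshold
    logger.debug("per-row details suppressed")  # filtered out at INFO level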