sango07 committed on
Commit
acea53d
·
verified ·
1 Parent(s): 4668bbd

Create test_case_generator.py

Browse files
Files changed (1) hide show
  1. test_case_generator.py +125 -0
test_case_generator.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import pandas as pd
4
+ from langchain_openai import ChatOpenAI
5
+ from langchain_core.prompts import PromptTemplate
6
+ from langchain_community.document_loaders import PyPDFLoader
7
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
8
+
9
class TestCaseGenerator:
    """Generate question/answer test cases from a PDF with an OpenAI chat model.

    The PDF is split into text chunks, each chunk is sent to the model with a
    prompt template chosen by question type, and the model's replies are
    normalised to JSON and collected into a pandas DataFrame.
    """

    # The class name starts with "Test"; this keeps pytest from trying to
    # collect it as a test class when it is imported into a test module.
    __test__ = False

    def __init__(self, api_key=None):
        """Create a generator.

        Args:
            api_key: Optional OpenAI API key. When truthy it is written to the
                ``OPENAI_API_KEY`` environment variable (a process-wide side
                effect); otherwise the environment is left untouched.
        """
        # Allow API key to be passed in or read from environment
        if api_key:
            os.environ["OPENAI_API_KEY"] = api_key

        # Predefined question types. Only a subset currently has a prompt
        # template wired up in get_prompt_template().
        self.available_question_types = [
            'cause_and_effect_reasoning',
            'temporal_reasoning',
            'object_affordance',
            'adversarial_tasks',
            'common_sense_reasoning',
            'hallucination',
            'sycophancy',
        ]

    def load_and_split_document(self, doc, chunk_size=1000, chunk_overlap=100):
        """Load a PDF and split it into manageable chunks.

        Args:
            doc: Either a filesystem path (``str`` or ``os.PathLike``) to a
                PDF, or an in-memory upload exposing ``getvalue()`` that
                returns the PDF bytes (e.g. a ``BytesIO`` / Streamlit upload).
            chunk_size: Maximum number of characters per chunk.
            chunk_overlap: Number of characters shared by adjacent chunks.

        Returns:
            The list of document chunks produced by
            ``RecursiveCharacterTextSplitter.split_documents``.
        """
        if isinstance(doc, (str, os.PathLike)):
            docs = PyPDFLoader(os.fspath(doc)).load()
        else:
            # PyPDFLoader needs a path, so spill the upload to disk. A unique
            # temporary file avoids concurrent uploads clobbering each other,
            # and it is removed once the pages have been read.
            import tempfile

            with tempfile.NamedTemporaryFile(suffix='.pdf', delete=False) as tmp:
                tmp.write(doc.getvalue())
                temp_path = tmp.name
            try:
                docs = PyPDFLoader(temp_path).load()
            finally:
                os.remove(temp_path)

        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
            is_separator_regex=False,
        )
        return text_splitter.split_documents(docs)

    def get_prompt_template(self, question_type):
        """Return the prompt template for ``question_type``, or ``None``.

        Templates are looked up as module-level names identical to the
        question type. NOTE(review): those names are neither defined nor
        imported in the visible part of this file, so they are resolved
        lazily here; a missing template yields ``None`` instead of a
        ``NameError``.

        Args:
            question_type: Name of the question type.

        Returns:
            The template object bound to that name at module level, or
            ``None`` if the type is unsupported or its template is missing.
        """
        supported = (
            "cause_and_effect_reasoning",
            "temporal_reasoning",
            "object_affordance",
            # Add other prompts as needed
        )
        if question_type not in supported:
            return None
        return globals().get(question_type)

    def extract_json_from_response(self, llm_response):
        """Ask the model to reformat ``llm_response`` as JSON.

        Args:
            llm_response: Raw text returned by the question-generation call.

        Returns:
            The ``content`` of the model's reply, which is requested to follow
            the ``{"questions": [{"id", "question", "answer"}]}`` structure.
        """
        llm = ChatOpenAI(temperature=0.25, model="gpt-3.5-turbo")
        # Doubled braces are literal braces for PromptTemplate; only
        # {input_json} is substituted.
        clean_prompt = """
        You're a highly skilled JSON validator and formatter.
        Convert the following text into a valid JSON format:
        {input_json}

        Ensure the output follows this structure:
        {{
        "questions": [
        {{
        "id": 1,
        "question": "...",
        "answer": "..."
        }}
        ]
        }}
        """

        prompt_template = PromptTemplate.from_template(clean_prompt)
        final = prompt_template.format(input_json=llm_response)
        return llm.invoke(final).content

    @staticmethod
    def _strip_code_fence(text):
        """Remove a surrounding Markdown code fence (```json ... ```), if any."""
        stripped = text.strip()
        if not stripped.startswith("```"):
            return stripped
        # Drop the opening fence line, including an optional language tag.
        newline = stripped.find("\n")
        stripped = stripped[newline + 1:] if newline != -1 else stripped[3:]
        stripped = stripped.rstrip()
        if stripped.endswith("```"):
            stripped = stripped[:-3]
        return stripped.strip()

    def convert_qa_to_df(self, llm_response):
        """Convert an LLM response to a pandas DataFrame.

        Args:
            llm_response: Either a JSON string (optionally wrapped in a
                Markdown code fence) or an already-parsed mapping with a
                ``"questions"`` list of ``{"question", "answer"}`` records.

        Returns:
            A DataFrame with exactly the columns ``question`` and ``answer``.
            It is empty (but still has both columns) when the response cannot
            be parsed or contains no questions.
        """
        columns = ['question', 'answer']
        try:
            if isinstance(llm_response, str):
                data = json.loads(self._strip_code_fence(llm_response))
            else:
                data = llm_response

            questions_data = data.get('questions', [])
            if not questions_data:
                return pd.DataFrame(columns=columns)
            return pd.DataFrame(questions_data)[columns]
        except (ValueError, KeyError, TypeError, AttributeError) as e:
            # Malformed JSON, wrong container type, or missing columns.
            print(f"Error processing response: {e}")
            return pd.DataFrame(columns=columns)

    def generate_testcases(self, doc, question_type, num_testcases=10, temperature=0.7):
        """Generate test cases for a specific question type.

        Args:
            doc: PDF path or in-memory upload; see load_and_split_document().
            question_type: Question type whose prompt template is used.
            num_testcases: Maximum number of rows to return.
            temperature: Sampling temperature for the generation model.

        Returns:
            A DataFrame with columns ``question``, ``answer`` and
            ``question_type`` holding at most ``num_testcases`` rows.

        Raises:
            ValueError: If no prompt template exists for ``question_type``.
        """
        # Validate first so an invalid type fails before the PDF is loaded.
        prompt = self.get_prompt_template(question_type)
        if prompt is None:
            raise ValueError(f"Invalid question type: {question_type}")

        docs = self.load_and_split_document(doc)
        model = ChatOpenAI(temperature=temperature, model="gpt-3.5-turbo")
        prompt_template = PromptTemplate.from_template(prompt)

        columns = ['question', 'answer', 'question_type']
        frames = []
        question_count = 0

        for doc_chunk in docs:
            if question_count >= num_testcases:
                break

            final_formatted_prompt = prompt_template.format(context=doc_chunk.page_content)
            response = model.invoke(final_formatted_prompt).content

            # Clean-up is best effort per chunk: a failure here skips the
            # chunk rather than aborting the whole run.
            try:
                cleaned_json = self.extract_json_from_response(response)
                df = self.convert_qa_to_df(cleaned_json)
            except Exception as e:
                print(f"Error generating questions: {e}")
                continue

            if df.empty:
                continue
            df['question_type'] = question_type
            frames.append(df)
            question_count += len(df)

        if not frames:
            return pd.DataFrame(columns=columns)
        # Concatenate once at the end instead of growing a frame in the loop.
        return pd.concat(frames, ignore_index=True).head(num_testcases)