namberino commited on
Commit
c357e06
·
1 Parent(s): 98e6c10
Files changed (1) hide show
  1. app.py +142 -67
app.py CHANGED
@@ -1,86 +1,161 @@
1
- # gradio_with_testclient.py
2
  import gradio as gr
 
3
  from fastapi.testclient import TestClient
4
- import os
5
- from fastapi_app import app as fastapi_app
6
 
 
7
  client = TestClient(fastapi_app)
8
 
9
- def call_generate(file_obj, topics: str, n_questions: int, difficulty: str, qtype: str):
10
- # Prepare file tuple expected by requests-like interface:
11
  if file_obj is None:
12
- return "Please upload a PDF."
13
 
14
- # Gradio may provide a filepath (str) or a tempfile path. Handle both:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  if isinstance(file_obj, str):
16
- f = open(file_obj, "rb")
17
- filename = os.path.basename(file_obj)
18
- close_after = True
19
- else:
20
- # file_obj is a file-like object (SpooledTemporaryFile / BytesIO)
21
- f = file_obj
22
- filename = getattr(file_obj, "name", "uploaded.pdf")
23
- close_after = False
24
-
25
- files = {
26
- "file": (filename, f, "application/pdf")
27
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  data = {
30
- "topics": topics,
31
  "n_questions": str(n_questions),
32
- "difficulty": difficulty.lower(),
33
- "qtype": qtype.lower(),
34
  }
35
 
36
  try:
37
- resp = client.post("/generate/", files=files, data=data, timeout=120)
38
- finally:
39
- if close_after:
40
- f.close()
 
41
 
42
  if resp.status_code != 200:
43
- return f"Server returned {resp.status_code}: {resp.text}"
44
-
45
- j = resp.json()
46
- # Format output
47
- out_lines = []
48
- out_lines.append(f"**Topics:** {', '.join(j.get('topics', []))}")
49
- out_lines.append(f"**Avg confidence:** {j.get('avg_confidence', 0):.3f}")
50
- out_lines.append(f"**Generation time:** {j.get('generation_time', 0):.2f} sec")
51
- out_lines.append("---")
52
-
53
- generated = j.get("generated", [])
54
- if not generated:
55
- out_lines.append("No questions returned.")
56
- return "\n\n".join(out_lines)
57
-
58
- for i, mcq in enumerate(generated, start=1):
59
- out_lines.append(f"### Q{i}: {mcq.get('question')}")
60
- options = mcq.get("options", {})
61
- for k in sorted(options.keys()):
62
- out_lines.append(f"**{k}.** {options[k]}")
63
- out_lines.append(f"**Correct:** {mcq.get('correct_answer', '')} Confidence: {mcq.get('confidence_score', 0):.3f}")
64
- out_lines.append("---")
65
-
66
- return "\n\n".join(out_lines)
67
-
68
-
69
- with gr.Blocks(title="RAG MCQ (in-process via TestClient)") as demo:
70
- gr.Markdown("# RAG MCQ Generation — TestClient in-process")
71
  with gr.Row():
72
- with gr.Column(scale=2):
73
- pdf_in = gr.File(label="Upload PDF (required)", file_types=[".pdf"])
74
- topics = gr.Textbox(label="Topics (comma-separated)", placeholder="e.g. linear algebra, eigenvalues", value="")
75
- n_questions = gr.Slider(label="Questions per topic", minimum=1, maximum=10, step=1, value=1)
76
- difficulty = gr.Dropdown(label="Difficulty", choices=["easy", "medium", "hard"], value="medium")
77
- qtype = gr.Dropdown(label="Question type", choices=["definition", "application", "conceptual", "calculation"], value="definition")
78
- submit = gr.Button("Generate MCQs")
79
- with gr.Column(scale=3):
80
- output = gr.Markdown("Results will appear here")
81
-
82
- submit.click(fn=call_generate, inputs=[pdf_in, topics, n_questions, difficulty, qtype], outputs=[output])
 
 
 
83
 
84
  if __name__ == "__main__":
85
- # demo.launch(server_name="0.0.0.0", server_port=7860)
86
- demo.launch()
 
 
1
  import gradio as gr
2
+ import json
3
  from fastapi.testclient import TestClient
4
+ from fastapi_app import app as fastapi_app # import your renamed FastAPI module
5
+ import io
6
 
7
+ # In-process client for the FastAPI app
8
  client = TestClient(fastapi_app)
9
 
10
+ def read_file_input(file_obj):
 
11
  if file_obj is None:
12
+ return None
13
 
14
+ # file-like object (has read)
15
+ if hasattr(file_obj, "read"):
16
+ try:
17
+ file_obj.seek(0)
18
+ except Exception:
19
+ pass
20
+ try:
21
+ data = file_obj.read()
22
+ # If read returns str (rare), encode it
23
+ if isinstance(data, str):
24
+ return data.encode()
25
+ return data
26
+ except Exception:
27
+ # continue to other strategies
28
+ pass
29
+
30
+ # raw bytes
31
+ if isinstance(file_obj, (bytes, bytearray)):
32
+ return bytes(file_obj)
33
+
34
+ # path string (local path)
35
  if isinstance(file_obj, str):
36
+ try:
37
+ with open(file_obj, "rb") as f:
38
+ return f.read()
39
+ except Exception:
40
+ # not a path, fall through to try encoding string
41
+ return file_obj.encode()
42
+
43
+ # dict-like (old Gradio or different frontends)
44
+ try:
45
+ if isinstance(file_obj, dict):
46
+ # common keys: "name", "data"
47
+ if "data" in file_obj:
48
+ data = file_obj["data"]
49
+ if isinstance(data, (bytes, bytearray)):
50
+ return bytes(data)
51
+ if isinstance(data, str):
52
+ return data.encode()
53
+ if "name" in file_obj:
54
+ maybe_path = file_obj["name"]
55
+ if isinstance(maybe_path, str):
56
+ try:
57
+ with open(maybe_path, "rb") as f:
58
+ return f.read()
59
+ except Exception:
60
+ pass
61
+ except Exception:
62
+ pass
63
+
64
+ # Object with attributes (NamedString with .name/.value)
65
+ try:
66
+ name = getattr(file_obj, "name", None)
67
+ data = getattr(file_obj, "data", None)
68
+ value = getattr(file_obj, "value", None)
69
+ if isinstance(data, (bytes, bytearray)):
70
+ return bytes(data)
71
+ if isinstance(value, (bytes, bytearray)):
72
+ return bytes(value)
73
+ if isinstance(value, str):
74
+ return value.encode()
75
+ if isinstance(name, str):
76
+ try:
77
+ with open(name, "rb") as f:
78
+ return f.read()
79
+ except Exception:
80
+ pass
81
+ except Exception:
82
+ pass
83
+
84
+ # String representation encoded
85
+ try:
86
+ return str(file_obj).encode()
87
+ except Exception:
88
+ return None
89
+
90
+ def call_generate(file_obj, topics, n_questions, difficulty, question_type):
91
+ if file_obj is None:
92
+ return {"error": "No file uploaded."}
93
+
94
+ # Read the uploaded file bytes and create multipart payload
95
+ file_bytes = read_file_input(file_obj)
96
+ if not file_bytes:
97
+ return {"error": "Could not read uploaded file (empty or unknown format)."}
98
+ files = {"file": ("uploaded_file", file_bytes, "application/octet-stream")}
99
+
100
+ print(files)
101
 
102
  data = {
103
+ "topics": topics if topics is not None else "",
104
  "n_questions": str(n_questions),
105
+ "difficulty": difficulty if difficulty is not None else "",
106
+ "question_type": question_type if question_type is not None else ""
107
  }
108
 
109
  try:
110
+ resp = client.post("/generate/", files=files, data=data, timeout=120) # increase timeout if needed
111
+ except Exception as e:
112
+ return {"error": f"Request failed: {e}"}
113
+
114
+ print(resp.status_code)
115
 
116
  if resp.status_code != 200:
117
+ # return helpful debug info
118
+ return {
119
+ "status_code": resp.status_code,
120
+ "text": resp.text,
121
+ "json": None
122
+ }
123
+
124
+ # print(resp.text)
125
+
126
+ # Parse JSON response
127
+ try:
128
+ out = resp.json()
129
+ except Exception:
130
+ # maybe the endpoint returns text: return it directly
131
+ return {"text": resp.text}
132
+
133
+ # pretty-format the JSON for display
134
+ return out
135
+
136
+ # Gradio UI
137
+ with gr.Blocks(title="RAG MCQ generator") as gradio_app:
138
+ gr.Markdown("## Upload a file and generate MCQs")
139
+
140
+ with gr.Row():
141
+ file_input = gr.File(label="Upload file (PDF, docx, etc)", type="filepath", file_types=[".pdf"])
142
+ topics = gr.Textbox(label="Topics (comma separated)", placeholder="e.g. calculus, derivatives")
 
 
143
  with gr.Row():
144
+ n_questions = gr.Slider(minimum=1, maximum=50, step=1, value=5, label="Number of questions")
145
+ difficulty = gr.Dropdown(choices=["easy", "medium", "hard"], value="medium", label="Difficulty")
146
+ question_type = gr.Dropdown(choices=["mcq", "short", "long"], value="mcq", label="Question type")
147
+
148
+ generate_btn = gr.Button("Generate")
149
+ output = gr.JSON(label="Response")
150
+
151
+ generate_btn.click(
152
+ fn=call_generate,
153
+ inputs=[file_input, topics, n_questions, difficulty, question_type],
154
+ outputs=[output],
155
+ )
156
+
157
+ app = gradio_app
158
 
159
  if __name__ == "__main__":
160
+ # gradio_app.launch(server_name="0.0.0.0", server_port=7860, share=False)
161
+ gradio_app.launch()