michaelmc1618 committed on
Commit a6b2b74 · verified · 1 Parent(s): 07a3b9b

Update app.py

Files changed (1):
  1. app.py +142 -197

app.py CHANGED
@@ -1,33 +1,25 @@
  import os
- import tempfile
- import torch
- import yt_dlp as youtube_dl
  import gradio as gr
- from transformers import pipeline, AutoTokenizer, AutoModelForMaskedLM, AutoProcessor, AutoModelForSpeechSeq2Seq
  from huggingface_hub import InferenceClient
  from datasets import load_dataset
  import fitz  # PyMuPDF
- from transformers.pipelines.audio_utils import ffmpeg_read
-
- # Constants for Whisper ASR
- MODEL_NAME = "openai/whisper-large-v3"
- BATCH_SIZE = 8
- FILE_LIMIT_MB = 1000
- YT_LENGTH_LIMIT_S = 3600  # limit to 1 hour YouTube files
-
- device = 0 if torch.cuda.is_available() else "cpu"

- # Load the Whisper model and processor
- processor = AutoProcessor.from_pretrained(MODEL_NAME)
- model_s2s = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME)

- # Load the BERT model and tokenizer
- tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
- model = AutoModelForMaskedLM.from_pretrained("google-bert/bert-base-uncased")
-
- # Create the fill-mask pipeline
- pipe = pipeline("fill-mask", model=model, tokenizer=tokenizer)

  client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")

  def respond(
@@ -48,123 +40,49 @@ def respond(
      messages.append({"role": "user", "content": message})

-     try:
-         response = ""
-         for message in client.chat_completion(
-             messages,
-             max_tokens=max_tokens,
-             stream=True,
-             temperature=temperature,
-             top_p=top_p,
-         ):
-             token = message.choices[0].delta.content
-             if token is not None:
-                 response += token
-             yield response, history + [(message, response)]
-     except Exception as e:
-         print(f"Error during chat completion: {e}")
-         yield "An error occurred during the chat completion.", history

  def generate_case_outcome(prosecutor_response, defense_response):
-     prompt = f"Prosecutor's arguments: {prosecutor_response}\n\nDefense's arguments: {defense_response}\n\nProvide details on who won the case and why. Provide reasons for your decision and provide a link to the source of the case."
      evaluation = ""
-     try:
-         for message in client.chat_completion(
-             [{"role": "system", "content": "You are a legal expert evaluating the details of the case presented by the prosecution and the defense."},
-              {"role": "user", "content": prompt}],
-             max_tokens=512,
-             stream=True,
-             temperature=0.6,
-             top_p=0.95,
-         ):
-             token = message.choices[0].delta.content
-             if token is not None:
-                 evaluation += token
-     except Exception as e:
-         print(f"Error during case outcome generation: {e}")
-         return "An error occurred during the case outcome generation."
      return evaluation

- def determine_outcome(outcome):
-     prosecutor_count = outcome.split().count("Prosecutor")
-     defense_count = outcome.split().count("Defense")
-     if prosecutor_count > defense_count:
          return "Prosecutor Wins"
-     elif defense_count > prosecutor_count:
          return "Defense Wins"
      else:
          return "No clear winner"

- def transcribe(inputs, task):
-     if inputs is None:
-         raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
-
-     inputs = processor(inputs, return_tensors="pt", sampling_rate=16000).to(device)
-     with torch.no_grad():
-         generated_ids = model_s2s.generate(inputs["input_features"])
-     transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-
-     return transcription
-
- def _return_yt_html_embed(yt_url):
-     video_id = yt_url.split("?v=")[-1]
-     HTML_str = (
-         f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>'
-         " </center>"
-     )
-     return HTML_str
-
- def download_yt_audio(yt_url, filename):
-     info_loader = youtube_dl.YoutubeDL()
-
-     try:
-         info = info_loader.extract_info(yt_url, download=False)
-     except youtube_dl.utils.DownloadError as err:
-         raise gr.Error(str(err))
-
-     file_length = info["duration_string"]
-     file_h_m_s = file_length.split(":")
-     file_h_m_s = [int(sub_length) for sub_length in file_h_m_s]
-
-     if len(file_h_m_s) == 1:
-         file_h_m_s.insert(0, 0)
-     if len(file_h_m_s) == 2:
-         file_h_m_s.insert(0, 0)
-     file_length_s = file_h_m_s[0] * 3600 + file_h_m_s[1] * 60 + file_h_m_s[2]
-
-     if file_length_s > YT_LENGTH_LIMIT_S:
-         yt_length_limit_hms = time.strftime("%HH:%MM:%SS", time.gmtime(YT_LENGTH_LIMIT_S))
-         file_length_hms = time.strftime("%HH:%MM:%SS", time.gmtime(file_length_s))
-         raise gr.Error(f"Maximum YouTube length is {yt_length_limit_hms}, got {file_length_hms} YouTube video.")
-
-     ydl_opts = {"outtmpl": filename, "format": "worstvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best"}
-
-     with youtube_dl.YoutubeDL(ydl_opts) as ydl:
-         try:
-             ydl.download([yt_url])
-         except youtube_dl.utils.ExtractorError as err:
-             raise gr.Error(str(err))
-
- def yt_transcribe(yt_url, task, max_filesize=75.0):
-     html_embed_str = _return_yt_html_embed(yt_url)
-
-     with tempfile.TemporaryDirectory() as tmpdirname:
-         filepath = os.path.join(tmpdirname, "video.mp4")
-         download_yt_audio(yt_url, filepath)
-         with open(filepath, "rb") as f:
-             inputs = f.read()
-
-     inputs = ffmpeg_read(inputs, processor.feature_extractor.sampling_rate)
-     inputs = {"array": inputs, "sampling_rate": processor.feature_extractor.sampling_rate}
-
-     inputs = processor(inputs, return_tensors="pt", sampling_rate=16000).to(device)
-     with torch.no_grad():
-         generated_ids = model_s2s.generate(inputs["input_features"])
-     transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-
-     return html_embed_str, transcription
-
- # Custom CSS for white background and black text for input and output boxes
  custom_css = """
  body {
      background-color: #ffffff;
@@ -253,17 +171,63 @@ def chat_between_bots(system_message1, system_message2, max_tokens, temperature,
      response2 = response2[:max_length]

      outcome = generate_case_outcome(response1, response2)
-     winner = determine_outcome(outcome)

-     return response1, response2, history1, history2, shared_history, outcome

  def get_top_10_cases():
-     prompt = "List 10 high-profile legal cases that have received significant media attention and are currently ongoing. Just a list of case names and numbers."
      response = ""
      for message in client.chat_completion(
-         [{"role": "system", "content": "You are a legal research expert, able to provide information about high-profile legal cases."},
-          {"role": "user", "content": prompt}],
-         max_tokens=512,
          stream=True,
          temperature=0.6,
          top_p=0.95,
@@ -271,14 +235,8 @@ def get_top_10_cases():
          token = message.choices[0].delta.content
          if token is not None:
              response += token
-     return response
-
- def add_message(history, message):
-     for x in message["files"]:
-         history.append(((x,), None))
-     if message["text"] is not None:
-         history.append((message["text"], None))
-     return history, gr.MultimodalTextbox(value=None, interactive=True)

  def print_like_dislike(x: gr.LikeData):
      print(x.index, x.value, x.liked)
@@ -290,32 +248,22 @@ def save_conversation(history1, history2, shared_history):
      return history1, history2, shared_history

  def ask_about_case_outcome(shared_history, question):
-     prompt = f"Case Outcome: {shared_history}\n\nQuestion: {question}\n\nAnswer:"
-     response = ""
-     for message in client.chat_completion(
-         [{"role": "system", "content": "You are a legal expert answering questions based on the case outcome provided."},
-          {"role": "user", "content": prompt}],
-         max_tokens=512,
-         stream=True,
-         temperature=0.6,
-         top_p=0.95,
-     ):
-         token = message.choices[0].delta.content
-         if token is not None:
-             response += token
-     return response

  with gr.Blocks(css=custom_css) as demo:
      history1 = gr.State([])
      history2 = gr.State([])
      shared_history = gr.State([])
      top_10_cases = gr.State("")

      with gr.Tab("Argument Evaluation"):
          with gr.Row():
              with gr.Column(scale=1):
                  top_10_btn = gr.Button("Give me the top 10 cases")
-                 top_10_output = gr.Textbox(label="Top 10 Cases", interactive=False, elem_classes=["scroll-box"])
                  top_10_btn.click(get_top_10_cases, outputs=top_10_output)
              with gr.Column(scale=2):
                  message = gr.Textbox(label="Case to Argue")
@@ -336,56 +284,53 @@ with gr.Blocks(css=custom_css) as demo:
          with gr.Column(scale=1):
              defense_score_color = gr.HTML()

-         outcome = gr.Textbox(label="Outcome", interactive=False, elem_classes=["scroll-box"])

          with gr.Row():
              submit_btn = gr.Button("Argue")
              clear_btn = gr.Button("Clear and Reset")
              save_btn = gr.Button("Save Conversation")

-         submit_btn.click(chat_between_bots, inputs=[system_message1, system_message2, max_tokens, temperature, top_p, history1, history2, shared_history, message], outputs=[prosecutor_response, defense_response, history1, history2, shared_history, outcome])
-         clear_btn.click(reset_conversation, outputs=[history1, history2, shared_history, prosecutor_response, defense_response, outcome])
          save_btn.click(save_conversation, inputs=[history1, history2, shared_history], outputs=[history1, history2, shared_history])
-
-     with gr.Tab("Practice Arguments"):
-         mf_transcribe = gr.Interface(
-             fn=transcribe,
-             inputs=[
-                 gr.Audio(type="filepath", label="Record or Upload Audio"),
-                 gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
-             ],
-             outputs="text",
-             layout="horizontal",
-             title="Practice Legal Arguments - Microphone",
-             description=(
-                 "Practice your legal arguments by recording them through your microphone or uploading an audio file. The arguments will be transcribed for review."
-             ),
-             allow_flagging="never",
-         )
-
-         yt_transcribe = gr.Interface(
-             fn=yt_transcribe,
-             inputs=[
-                 gr.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL"),
-                 gr.Radio(["transcribe", "translate"], label="Task", value="transcribe")
-             ],
-             outputs=["html", "text"],
-             layout="horizontal",
-             title="Practice Legal Arguments - YouTube",
-             description=(
-                 "Practice your legal arguments by providing a YouTube video link. The arguments will be transcribed for review."
-             ),
-             allow_flagging="never",
          )

-     gr.TabbedInterface([mf_transcribe, yt_transcribe], ["Microphone", "YouTube"])

-     with gr.Tab("Case Outcome Chat"):
-         case_question = gr.Textbox(label="Ask a Question about the Case Outcome")
-         case_answer = gr.Textbox(label="Answer", interactive=False, elem_classes=["scroll-box"])
-         ask_case_btn = gr.Button("Ask")

-         ask_case_btn.click(ask_about_case_outcome, inputs=[shared_history, case_question], outputs=case_answer)

  demo.queue()
  demo.launch()
 
  import os
+ os.system('pip install transformers')
+ os.system('pip install datasets')
+ os.system('pip install gradio')
+ os.system('pip install minijinja')
+ os.system('pip install PyMuPDF')
+
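Note: `os.system('pip install ...')` at import time ignores failures and reinstalls on every restart; on Hugging Face Spaces these dependencies are normally declared once in `requirements.txt`. If installing from code anyway, a minimal sketch that at least surfaces errors (the helper name is illustrative):

    import subprocess
    import sys

    def ensure_installed(*packages):
        # Run pip via the current interpreter; check=True raises on a non-zero
        # exit code, which os.system silently ignores.
        subprocess.run([sys.executable, "-m", "pip", "install", *packages], check=True)

    ensure_installed("transformers", "datasets", "gradio", "minijinja", "PyMuPDF")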
  import gradio as gr
  from huggingface_hub import InferenceClient
+ from transformers import pipeline
  from datasets import load_dataset
  import fitz  # PyMuPDF

+ # Load dataset
+ dataset = load_dataset("ibunescu/qa_legal_dataset_train")

+ # Different pipelines for different tasks
+ qa_pipeline = pipeline("question-answering", model="deepset/roberta-base-squad2")
+ summarization_pipeline = pipeline("summarization", model="facebook/bart-large-cnn")
+ mask_filling_pipeline = pipeline("fill-mask", model="nlpaueb/legal-bert-base-uncased")

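Note: the three pipelines have different calling conventions; a quick sketch with invented inputs:

    # Extractive QA returns a dict with "answer", "score", "start", "end".
    qa = qa_pipeline(question="Who signed the contract?",
                     context="The contract was signed by Alice Smith on 1 March 2020.")
    print(qa["answer"])

    # Summarization returns a list of dicts with "summary_text".
    long_text = "The court heard arguments from both parties. " * 20
    print(summarization_pipeline(long_text, max_length=30, min_length=5)[0]["summary_text"])

    # Fill-mask uses a BERT-style [MASK] token and returns ranked candidates.
    print(mask_filling_pipeline("The defendant was found [MASK] of all charges.")[0]["token_str"])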
+ # Inference client for chat completion
  client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")

  def respond(
      messages.append({"role": "user", "content": message})

+     response = ""
+     for chunk in client.chat_completion(
+         messages,
+         max_tokens=max_tokens,
+         stream=True,
+         temperature=temperature,
+         top_p=top_p,
+     ):
+         token = chunk.choices[0].delta.content
+         if token is not None:
+             response += token
+         yield response, history + [(message, response)]
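Note: the loop above accumulates streamed deltas into a growing string and yields after every token so the UI can render partial output. (The loop variable is renamed here from `message` to `chunk`, since reusing `message` shadowed the user's message that the yield appends to the history.) The same pattern as a standalone sketch, with an illustrative prompt:

    from huggingface_hub import InferenceClient

    client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
    text = ""
    for chunk in client.chat_completion(
        [{"role": "user", "content": "Summarize Miranda v. Arizona in one sentence."}],
        max_tokens=64,
        stream=True,
    ):
        delta = chunk.choices[0].delta.content  # may be None for control chunks
        if delta is not None:
            text += delta
    print(text)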
 
  def generate_case_outcome(prosecutor_response, defense_response):
+     prompt = f"Prosecutor's Argument: {prosecutor_response}\nDefense Attorney's Argument: {defense_response}\n\nEvaluate both arguments, point out the strengths and weaknesses, and determine who won the case. Provide reasons for your decision."
      evaluation = ""
+     for message in client.chat_completion(
+         [{"role": "system", "content": "You are a legal expert evaluating the arguments presented by the prosecution and the defense."},
+          {"role": "user", "content": prompt}],
+         max_tokens=512,
+         stream=True,
+         temperature=0.6,
+         top_p=0.95,
+     ):
+         token = message.choices[0].delta.content
+         if token is not None:
+             evaluation += token
      return evaluation

+ def determine_winner(outcome):
+     if "Prosecutor" in outcome and "Defense" in outcome:
+         if outcome.count("Prosecutor") > outcome.count("Defense"):
+             return "Prosecutor Wins"
+         else:
+             return "Defense Wins"
+     elif "Prosecutor" in outcome:
          return "Prosecutor Wins"
+     elif "Defense" in outcome:
          return "Defense Wins"
      else:
          return "No clear winner"
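Note: `determine_winner` is a case-sensitive mention-counting heuristic over the generated evaluation, not a structured verdict. Its behavior on made-up inputs:

    assert determine_winner("The Prosecutor led; the Defense faltered. Prosecutor wins.") == "Prosecutor Wins"
    assert determine_winner("Only the Defense presented evidence.") == "Defense Wins"
    assert determine_winner("Neither side prevailed.") == "No clear winner"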
 
+ # Custom CSS for a clean layout
  custom_css = """
  body {
      background-color: #ffffff;

      response2 = response2[:max_length]

      outcome = generate_case_outcome(response1, response2)
+     winner = determine_winner(outcome)

+     return response1, response2, history1, history2, shared_history, outcome, winner
+
+ def extract_text_from_pdf(pdf_file):
+     text = ""
+     doc = fitz.open(pdf_file)
+     for page in doc:
+         text += page.get_text()
+     return text
+
+ def ask_about_pdf(pdf_text, question):
+     result = qa_pipeline(question=question, context=pdf_text)
+     return result['answer']
+
+ def update_pdf_gallery_and_extract_text(pdf_files):
+     if len(pdf_files) > 0:
+         pdf_text = extract_text_from_pdf(pdf_files[0].name)
+     else:
+         pdf_text = ""
+     return pdf_files, pdf_text
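Note: `extract_text_from_pdf` concatenates the text of every page via PyMuPDF, and `ask_about_pdf` runs extractive QA over that text. Hypothetical usage (the path is illustrative):

    text = extract_text_from_pdf("filings/complaint.pdf")
    print(ask_about_pdf(text, "Who is the plaintiff?"))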
 
  def get_top_10_cases():
+     # Here, I'm generating a list of 10 example cases. In a real-world scenario, you'd fetch this data from a database or another source.
+     cases = [
+         {"name": "Smith v. Jones", "number": "CA12345"},
+         {"name": "Johnson v. State", "number": "CA67890"},
+         {"name": "Doe v. Roe", "number": "CA11223"},
+         {"name": "Brown v. Davis", "number": "CA44556"},
+         {"name": "Williams v. Taylor", "number": "CA77889"},
+         {"name": "Miller v. Anderson", "number": "CA99100"},
+         {"name": "Davis v. Martin", "number": "CA22334"},
+         {"name": "Garcia v. Clark", "number": "CA55667"},
+         {"name": "Rodriguez v. Lewis", "number": "CA88990"},
+         {"name": "Martinez v. Lee", "number": "CA10112"}
+     ]
+     return "\n".join([f"{case['name']} - Case Number: {case['number']}" for case in cases])
+
+ def add_message(history, message):
+     for x in message["files"]:
+         history.append(((x,), None))
+     if message["text"] is not None:
+         history.append((message["text"], None))
+     return history, gr.MultimodalTextbox(value=None, interactive=False)
+
+ def bot(history):
+     system_message = "You are a helpful assistant."
+     messages = [{"role": "system", "content": system_message}]
+     for val in history:
+         if val[0]:
+             messages.append({"role": "user", "content": val[0]})
+         if val[1]:
+             messages.append({"role": "assistant", "content": val[1]})
      response = ""
      for message in client.chat_completion(
+         messages,
+         max_tokens=150,
          stream=True,
          temperature=0.6,
          top_p=0.95,
      ):
          token = message.choices[0].delta.content
          if token is not None:
              response += token
+     history[-1][1] = response
+     yield history
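Note: `bot` rebuilds the full message list from the chat history on each call and streams the reply into the last row. A hypothetical offline walk-through (no UI; rows use the [user, assistant] list form the Chatbot component passes back):

    history = [["What is habeas corpus?", None]]
    for partial in bot(history):
        pass  # each yield carries the reply streamed so far
    print(history[-1][1])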
 
  def print_like_dislike(x: gr.LikeData):
      print(x.index, x.value, x.liked)

      return history1, history2, shared_history

  def ask_about_case_outcome(shared_history, question):
+     result = qa_pipeline(question=question, context=shared_history)
+     return result['answer']
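Note: `shared_history` is initialized as a list state, while `qa_pipeline` expects a plain-string context, so the call above likely needs a flattening step first; a hypothetical helper:

    def case_outcome_context(shared_history):
        # Join list-shaped history into a single string before using it as QA context.
        if isinstance(shared_history, list):
            return "\n".join(str(turn) for turn in shared_history)
        return shared_history or ""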
 
  with gr.Blocks(css=custom_css) as demo:
      history1 = gr.State([])
      history2 = gr.State([])
      shared_history = gr.State([])
+     pdf_files = gr.State([])
+     pdf_text = gr.State("")
      top_10_cases = gr.State("")

      with gr.Tab("Argument Evaluation"):
          with gr.Row():
              with gr.Column(scale=1):
                  top_10_btn = gr.Button("Give me the top 10 cases")
+                 top_10_output = gr.Markdown(elem_classes=["scroll-box"])
                  top_10_btn.click(get_top_10_cases, outputs=top_10_output)
              with gr.Column(scale=2):
                  message = gr.Textbox(label="Case to Argue")

          with gr.Column(scale=1):
              defense_score_color = gr.HTML()

+         shared_argument = gr.Textbox(label="Case Outcome", interactive=True, elem_classes=["scroll-box"])
+         winner = gr.Textbox(label="Winner", interactive=False, elem_classes=["scroll-box"])

          with gr.Row():
              submit_btn = gr.Button("Argue")
              clear_btn = gr.Button("Clear and Reset")
              save_btn = gr.Button("Save Conversation")

+         submit_btn.click(chat_between_bots, inputs=[system_message1, system_message2, max_tokens, temperature, top_p, history1, history2, shared_history, message], outputs=[prosecutor_response, defense_response, history1, history2, shared_history, shared_argument, winner])
+         clear_btn.click(reset_conversation, outputs=[history1, history2, shared_history, prosecutor_response, defense_response, shared_argument, winner])
          save_btn.click(save_conversation, inputs=[history1, history2, shared_history], outputs=[history1, history2, shared_history])
+
+         # Section for asking about the case outcome
+         with gr.Row():
+             case_question = gr.Textbox(label="Ask a Question about the Case Outcome")
+             case_answer = gr.Textbox(label="Answer", interactive=False, elem_classes=["scroll-box"])
+             ask_case_btn = gr.Button("Ask")
+
+         ask_case_btn.click(ask_about_case_outcome, inputs=[shared_history, case_question], outputs=case_answer)
+
+     with gr.Tab("PDF Management"):
+         pdf_upload = gr.File(label="Upload Case Files (PDF)", file_types=[".pdf"])
+         pdf_gallery = gr.Gallery(label="PDF Gallery")
+         pdf_view = gr.Textbox(label="PDF Content", interactive=False, elem_classes=["scroll-box"])
+         pdf_question = gr.Textbox(label="Ask a Question about the PDF")
+         pdf_answer = gr.Textbox(label="Answer", interactive=False, elem_classes=["scroll-box"])
+         pdf_upload_btn = gr.Button("Update PDF Gallery")
+         pdf_ask_btn = gr.Button("Ask")
+
+         pdf_upload_btn.click(update_pdf_gallery_and_extract_text, inputs=[pdf_upload], outputs=[pdf_gallery, pdf_text])
+         pdf_text.change(fn=lambda x: x, inputs=pdf_text, outputs=pdf_view)
+         pdf_ask_btn.click(ask_about_pdf, inputs=[pdf_text, pdf_question], outputs=pdf_answer)
+
+     with gr.Tab("Chatbot"):
+         chatbot = gr.Chatbot(
+             [],
+             elem_id="chatbot",
+             bubble_full_width=False
          )

+         chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload file...", show_label=False)

+         chat_msg = chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])
+         bot_msg = chat_msg.then(bot, chatbot, chatbot, api_name="bot_response")
+         bot_msg.then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])

+         chatbot.like(print_like_dislike, None, None)

  demo.queue()
  demo.launch()