minko186 commited on
Commit
c1769c1
·
1 Parent(s): f6b1cb0

Cleaned up the output format and switched all records of text to the new format

Browse files
Files changed (3) hide show
  1. ai_generate.py +52 -14
  2. app.py +21 -19
  3. humanize.py +33 -0
ai_generate.py CHANGED
@@ -111,18 +111,43 @@ def remove_citations(text):
111
  return text
112
 
113
 
114
- def process_cited_text(data, docs):
115
- # Initialize variables for the combined text and a dictionary for citations
116
  combined_text = ""
117
  citations = {}
118
  # Iterate through the cited_text list
119
  if 'cited_text' in data:
120
  for item in data['cited_text']:
121
- chunk_text = item['chunk'][0]['text']
122
- combined_text += chunk_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  citation_ids = []
124
- # Process the citations for the chunk
125
- if item['chunk'][1]['citations']:
126
  for c in item['chunk'][1]['citations']:
127
  if c and 'citation' in c:
128
  citation = c['citation']
@@ -133,16 +158,12 @@ def process_cited_text(data, docs):
133
  citation_ids.append(int(citation))
134
  except ValueError:
135
  pass # Handle cases where the string is not a valid integer
136
- if citation_ids:
137
- citation_texts = [f"<{cid}>" for cid in citation_ids]
138
- combined_text += " " + "".join(citation_texts)
139
- combined_text += "\n\n"
140
  # Store unique citations in a dictionary
141
  for citation_id in citation_ids:
142
  if citation_id not in citations:
143
  citations[citation_id] = {'source': docs[citation_id].metadata['source'], 'content': docs[citation_id].page_content}
144
 
145
- return combined_text.strip(), citations
146
 
147
 
148
  def citations_to_html(citations):
@@ -236,8 +257,8 @@ def generate_rag(
236
  | XMLOutputParser()
237
  )
238
  result = rag_chain.invoke({"input": prompt})
239
- text, citations = process_cited_text(result, docs)
240
- return text, citations
241
 
242
  def generate_base(
243
  prompt: str, topic: str, model: str, temperature: float, max_length: int, api_key: str, sys_message=""
@@ -248,7 +269,10 @@ def generate_base(
248
  return None, None
249
  try:
250
  output = llm.invoke(prompt).content
251
- return output, None
 
 
 
252
  except Exception as e:
253
  print(f"An error occurred while running the model: {e}")
254
  return None, None
@@ -269,3 +293,17 @@ def generate(
269
  return generate_rag(prompt, topic, model, url_content, path, temperature, max_length, api_key, sys_message)
270
  else:
271
  return generate_base(prompt, topic, model, temperature, max_length, api_key, sys_message)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  return text
112
 
113
 
114
+ def display_cited_text(data):
 
115
  combined_text = ""
116
  citations = {}
117
  # Iterate through the cited_text list
118
  if 'cited_text' in data:
119
  for item in data['cited_text']:
120
+ if 'chunk' in item and len(item['chunk']) > 0:
121
+ chunk_text = item['chunk'][0].get('text')
122
+ combined_text += chunk_text
123
+ citation_ids = []
124
+ # Process the citations for the chunk
125
+ if len(item['chunk']) > 1 and item['chunk'][1]['citations']:
126
+ for c in item['chunk'][1]['citations']:
127
+ if c and 'citation' in c:
128
+ citation = c['citation']
129
+ if isinstance(citation, dict) and "source_id" in citation:
130
+ citation = citation['source_id']
131
+ if isinstance(citation, str):
132
+ try:
133
+ citation_ids.append(int(citation))
134
+ except ValueError:
135
+ pass # Handle cases where the string is not a valid integer
136
+ if citation_ids:
137
+ citation_texts = [f"<{cid}>" for cid in citation_ids]
138
+ combined_text += " " + "".join(citation_texts)
139
+ combined_text += "\n\n"
140
+ return combined_text
141
+
142
+
143
+ def get_citations(data, docs):
144
+ # Initialize variables for the combined text and a dictionary for citations
145
+ citations = {}
146
+ # Iterate through the cited_text list
147
+ if data.get('cited_text'):
148
+ for item in data['cited_text']:
149
  citation_ids = []
150
+ if 'chunk' in item and len(item['chunk']) > 1 and item['chunk'][1].get('citations'):
 
151
  for c in item['chunk'][1]['citations']:
152
  if c and 'citation' in c:
153
  citation = c['citation']
 
158
  citation_ids.append(int(citation))
159
  except ValueError:
160
  pass # Handle cases where the string is not a valid integer
 
 
 
 
161
  # Store unique citations in a dictionary
162
  for citation_id in citation_ids:
163
  if citation_id not in citations:
164
  citations[citation_id] = {'source': docs[citation_id].metadata['source'], 'content': docs[citation_id].page_content}
165
 
166
+ return citations
167
 
168
 
169
  def citations_to_html(citations):
 
257
  | XMLOutputParser()
258
  )
259
  result = rag_chain.invoke({"input": prompt})
260
+ citations = get_citations(result, docs)
261
+ return result, citations
262
 
263
  def generate_base(
264
  prompt: str, topic: str, model: str, temperature: float, max_length: int, api_key: str, sys_message=""
 
269
  return None, None
270
  try:
271
  output = llm.invoke(prompt).content
272
+ output_dict = {'cited_text': [
273
+ {'chunk': [{'text': output}, {'citations': None}]}
274
+ ]}
275
+ return output_dict, None
276
  except Exception as e:
277
  print(f"An error occurred while running the model: {e}")
278
  return None, None
 
293
  return generate_rag(prompt, topic, model, url_content, path, temperature, max_length, api_key, sys_message)
294
  else:
295
  return generate_base(prompt, topic, model, temperature, max_length, api_key, sys_message)
296
+
297
+ # prompt = "Write a short 200 word report with an introduction about the current methods of ai detection and the results."
298
+ # topic = "the current methods of ai detection"
299
+
300
+ # text, citations = generate(
301
+ # prompt,
302
+ # topic,
303
+ # "OpenAI GPT 4o",
304
+ # None,
305
+ # ["./final_report.pdf","./detection_tools.pdf"],
306
+ # )
307
+ # from pprint import pprint
308
+ # print(text)
309
+ # print(citations)
app.py CHANGED
@@ -18,8 +18,8 @@ from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipe
18
 
19
  from utils import remove_special_characters, split_text_allow_complete_sentences_nltk
20
  from google_search import google_search, months, domain_list, build_date
21
- from humanize import humanize_text, device
22
- from ai_generate import generate, citations_to_html, remove_citations
23
  import nltk
24
  nltk.download('punkt_tab')
25
 
@@ -380,11 +380,14 @@ def generate_article(
380
  api_key=api_key,
381
  sys_message="",
382
  )
383
- return clean_text(article), citations_to_html(citations)
384
 
385
 
386
  def get_history(history):
387
- return history
 
 
 
388
 
389
 
390
  def clear_history():
@@ -393,7 +396,6 @@ def clear_history():
393
 
394
 
395
  def humanize(
396
- text: str,
397
  model: str,
398
  temperature: float = 1.2,
399
  repetition_penalty: float = 1,
@@ -402,21 +404,22 @@ def humanize(
402
  history=None,
403
  ) -> str:
404
  print("Humanizing text...")
405
- body, references = split_text_from_refs(text)
406
- result = humanize_text(
407
- text=body,
 
408
  model_name=model,
409
  temperature=temperature,
410
  repetition_penalty=repetition_penalty,
411
  top_k=top_k,
412
  length_penalty=length_penalty,
413
  )
414
- result = result + references
415
- corrected_text = format_and_correct_language_check(result)
416
 
417
  timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
418
- history.append((f"Humanized Text | {timestamp}\nInput: {model}", corrected_text))
419
- return corrected_text, history
420
 
421
 
422
  def update_visibility_api(model: str):
@@ -600,13 +603,13 @@ def generate_and_format(
600
  generated_article,
601
  user_comments,
602
  )
603
- if ends_with_references(article) and url_content is not None:
604
- for url in url_content.keys():
605
- article += f"\n{url}"
606
 
607
- reference_formatted = format_references(article)
608
  timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
609
- history.append((f"Generated Text | {timestamp}\nInput: {topic}", reference_formatted))
610
 
611
  # Save the article and metadata to Cloud Storage
612
  # We dont save if there is PDF input for privacy reasons
@@ -636,7 +639,7 @@ def generate_and_format(
636
  )
637
  print(save_message)
638
 
639
- return reference_formatted, citations, history
640
 
641
 
642
  def create_interface():
@@ -1024,7 +1027,6 @@ def create_interface():
1024
  humanize_btn.click(
1025
  fn=humanize,
1026
  inputs=[
1027
- output_article,
1028
  model_dropdown,
1029
  temperature_slider,
1030
  repetition_penalty_slider,
 
18
 
19
  from utils import remove_special_characters, split_text_allow_complete_sentences_nltk
20
  from google_search import google_search, months, domain_list, build_date
21
+ from humanize import humanize_text, device, humanize_chunk
22
+ from ai_generate import generate, citations_to_html, remove_citations, display_cited_text
23
  import nltk
24
  nltk.download('punkt_tab')
25
 
 
380
  api_key=api_key,
381
  sys_message="",
382
  )
383
+ return article, citations_to_html(citations)
384
 
385
 
386
  def get_history(history):
387
+ history_formatted = []
388
+ for entry in history:
389
+ history_formatted.append((entry[0], display_cited_text(entry[1])))
390
+ return history_formatted
391
 
392
 
393
  def clear_history():
 
396
 
397
 
398
  def humanize(
 
399
  model: str,
400
  temperature: float = 1.2,
401
  repetition_penalty: float = 1,
 
404
  history=None,
405
  ) -> str:
406
  print("Humanizing text...")
407
+ # body, references = split_text_from_refs(text)
408
+ cited_text = history[-1][1]
409
+ result = humanize_chunk(
410
+ data = cited_text,
411
  model_name=model,
412
  temperature=temperature,
413
  repetition_penalty=repetition_penalty,
414
  top_k=top_k,
415
  length_penalty=length_penalty,
416
  )
417
+ # result = result + references
418
+ # corrected_text = format_and_correct_language_check(result)
419
 
420
  timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
421
+ history.append((f"Humanized Text | {timestamp}\nInput: {model}", result))
422
+ return clean_text(display_cited_text(result)), history
423
 
424
 
425
  def update_visibility_api(model: str):
 
603
  generated_article,
604
  user_comments,
605
  )
606
+ # if ends_with_references(article) and url_content is not None:
607
+ # for url in url_content.keys():
608
+ # article += f"\n{url}"
609
 
610
+ # reference_formatted = format_references(article)
611
  timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
612
+ history.append((f"Generated Text | {timestamp}\nInput: {topic}", article))
613
 
614
  # Save the article and metadata to Cloud Storage
615
  # We dont save if there is PDF input for privacy reasons
 
639
  )
640
  print(save_message)
641
 
642
+ return clean_text(display_cited_text(article)), citations, history
643
 
644
 
645
  def create_interface():
 
1027
  humanize_btn.click(
1028
  fn=humanize,
1029
  inputs=[
 
1030
  model_dropdown,
1031
  temperature_slider,
1032
  repetition_penalty_slider,
humanize.py CHANGED
@@ -5,6 +5,7 @@ from nltk import sent_tokenize
5
  import gradio as gr
6
  from peft import PeftModel
7
  from transformers import T5ForConditionalGeneration, T5Tokenizer
 
8
 
9
  nltk.download("punkt")
10
 
@@ -49,6 +50,12 @@ FastLanguageModel.for_inference(dec_only_model) # native 2x faster inference
49
  print(f"Loaded model: {dec_only}, Num. params: {dec_only_model.num_parameters()}")
50
 
51
 
 
 
 
 
 
 
52
  def humanize_batch_seq2seq(model, tokenizer, sentences, temperature, repetition_penalty, top_k, length_penalty):
53
  inputs = ["Please paraphrase this sentence: " + sentence for sentence in sentences]
54
  inputs = tokenizer(inputs, return_tensors="pt", padding=True, truncation=True).to(model.device)
@@ -191,3 +198,29 @@ def humanize_text(
191
 
192
  humanized_text = "\n".join(humanized_paragraphs)
193
  return humanized_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import gradio as gr
6
  from peft import PeftModel
7
  from transformers import T5ForConditionalGeneration, T5Tokenizer
8
+ import language_tool_python
9
 
10
  nltk.download("punkt")
11
 
 
50
  print(f"Loaded model: {dec_only}, Num. params: {dec_only_model.num_parameters()}")
51
 
52
 
53
+ # grammar correction tool
54
+ tool = language_tool_python.LanguageTool("en-US")
55
+
56
+ def format_and_correct_language_check(text: str) -> str:
57
+ return tool.correct(text)
58
+
59
  def humanize_batch_seq2seq(model, tokenizer, sentences, temperature, repetition_penalty, top_k, length_penalty):
60
  inputs = ["Please paraphrase this sentence: " + sentence for sentence in sentences]
61
  inputs = tokenizer(inputs, return_tensors="pt", padding=True, truncation=True).to(model.device)
 
198
 
199
  humanized_text = "\n".join(humanized_paragraphs)
200
  return humanized_text
201
+
202
+
203
+ def humanize_chunk(
204
+ data,
205
+ progress=gr.Progress(),
206
+ model_name="Standard Model",
207
+ temperature=1.2,
208
+ repetition_penalty=1.0,
209
+ top_k=50,
210
+ length_penalty=1.0,
211
+ ):
212
+ humanized_chunks = {'cited_text': []}
213
+ if 'cited_text' in data:
214
+ for item in data['cited_text']:
215
+ humanized_chunk = {'chunk': [{'text': ""}, {'citations': None}]}
216
+ if 'chunk' in item and len(item['chunk']) > 0:
217
+ chunk_text = item['chunk'][0].get('text')
218
+ humanized_chunk['chunk'][0] = {'text': format_and_correct_language_check(humanize_text(chunk_text))}
219
+
220
+ citation_ids = []
221
+ # Process the citations for the chunk
222
+ if len(item['chunk']) > 1 and item['chunk'][1]['citations']:
223
+ humanized_chunk['chunk'][1] = {'citations': item['chunk'][1]['citations']}
224
+ humanized_chunks['cited_text'].append(humanized_chunk)
225
+ return humanized_chunks
226
+