eljanmahammadli commited on
Commit
b26a983
·
2 Parent(s): 88a1d09 80a07a7

Merge branch 'staging'

Browse files
Files changed (4) hide show
  1. ai_generate.py +246 -18
  2. app.py +740 -463
  3. humanize.py +58 -10
  4. requirements.txt +2 -1
ai_generate.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import os
2
  from langchain_community.document_loaders import PyMuPDFLoader
3
  from langchain_core.documents import Document
@@ -15,6 +16,16 @@ from langchain_openai import ChatOpenAI
15
  from langchain_google_genai import ChatGoogleGenerativeAI
16
  from langchain_anthropic import ChatAnthropic
17
  from dotenv import load_dotenv
 
 
 
 
 
 
 
 
 
 
18
 
19
  load_dotenv()
20
 
@@ -26,7 +37,17 @@ os.environ["GLOG_minloglevel"] = "2"
26
  CHUNK_SIZE = 1024
27
  CHUNK_OVERLAP = CHUNK_SIZE // 8
28
  K = 10
29
- FETCH_K = 20
 
 
 
 
 
 
 
 
 
 
30
 
31
  llm_model_translation = {
32
  "LLaMA 3": "llama3-70b-8192",
@@ -47,6 +68,138 @@ llm_classes = {
47
  }
48
 
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  def load_llm(model: str, api_key: str, temperature: float = 1.0, max_length: int = 2048):
51
  model_name = llm_model_translation.get(model)
52
  llm_class = llm_classes.get(model_name)
@@ -60,10 +213,9 @@ def load_llm(model: str, api_key: str, temperature: float = 1.0, max_length: int
60
  return llm
61
 
62
 
63
- def create_db_with_langchain(path: list[str], url_content: dict):
64
  all_docs = []
65
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
66
- embedding_function = SentenceTransformerEmbeddings(model_name="all-mpnet-base-v2")
67
  if path:
68
  for file in path:
69
  loader = PyMuPDFLoader(file)
@@ -79,18 +231,38 @@ def create_db_with_langchain(path: list[str], url_content: dict):
79
  docs = text_splitter.split_documents([doc])
80
  all_docs.extend(docs)
81
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  # print docs
83
  for idx, doc in enumerate(all_docs):
84
  print(f"Doc: {idx} | Length = {len(doc.page_content)}")
85
 
86
  assert len(all_docs) > 0, "No PDFs or scrapped data provided"
87
  db = Chroma.from_documents(all_docs, embedding_function)
 
 
88
  return db
89
 
90
 
 
 
 
 
91
  def generate_rag(
92
  prompt: str,
 
93
  topic: str,
 
94
  model: str,
95
  url_content: dict,
96
  path: list[str],
@@ -103,19 +275,25 @@ def generate_rag(
103
  if llm is None:
104
  print("Failed to load LLM. Aborting operation.")
105
  return None
106
- db = create_db_with_langchain(path, url_content)
107
- retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": K, "fetch_k": FETCH_K})
108
- rag_prompt = hub.pull("rlm/rag-prompt")
109
 
110
- def format_docs(docs):
111
- return "\n\n".join(doc.page_content for doc in docs)
 
 
112
 
113
- docs = retriever.get_relevant_documents(topic)
114
- formatted_docs = format_docs(docs)
115
- rag_chain = (
116
- {"context": lambda _: formatted_docs, "question": RunnablePassthrough()} | rag_prompt | llm | StrOutputParser()
117
- )
118
- return rag_chain.invoke(prompt)
 
 
 
 
 
 
 
119
 
120
 
121
  def generate_base(
@@ -124,18 +302,21 @@ def generate_base(
124
  llm = load_llm(model, api_key, temperature, max_length)
125
  if llm is None:
126
  print("Failed to load LLM. Aborting operation.")
127
- return None
128
  try:
129
  output = llm.invoke(prompt).content
130
- return output
 
131
  except Exception as e:
132
  print(f"An error occurred while running the model: {e}")
133
- return None
134
 
135
 
136
  def generate(
137
  prompt: str,
 
138
  topic: str,
 
139
  model: str,
140
  url_content: dict,
141
  path: list[str],
@@ -145,6 +326,53 @@ def generate(
145
  sys_message="",
146
  ):
147
  if path or url_content:
148
- return generate_rag(prompt, topic, model, url_content, path, temperature, max_length, api_key, sys_message)
 
 
149
  else:
150
  return generate_base(prompt, topic, model, temperature, max_length, api_key, sys_message)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
  import os
3
  from langchain_community.document_loaders import PyMuPDFLoader
4
  from langchain_core.documents import Document
 
16
  from langchain_google_genai import ChatGoogleGenerativeAI
17
  from langchain_anthropic import ChatAnthropic
18
  from dotenv import load_dotenv
19
+ from langchain_core.output_parsers import XMLOutputParser
20
+ from langchain.prompts import ChatPromptTemplate
21
+ import re
22
+ import numpy as np
23
+ import torch
24
+ import bm25s
25
+ from langchain_community.cross_encoders import HuggingFaceCrossEncoder
26
+ from langchain.retrievers import ContextualCompressionRetriever
27
+ from langchain.retrievers.document_compressors import CrossEncoderReranker
28
+ from langchain_core.messages import HumanMessage
29
 
30
  load_dotenv()
31
 
 
37
  CHUNK_SIZE = 1024
38
  CHUNK_OVERLAP = CHUNK_SIZE // 8
39
  K = 10
40
+ FETCH_K = 50
41
+
42
+ model_kwargs = {"device": "cuda:1"}
43
+ print("Loading embedding and reranker models...")
44
+ embedding_function = SentenceTransformerEmbeddings(
45
+ model_name="mixedbread-ai/mxbai-embed-large-v1", model_kwargs=model_kwargs
46
+ )
47
+ # "sentence-transformers/all-MiniLM-L6-v2"
48
+ # "mixedbread-ai/mxbai-embed-large-v1"
49
+ reranker = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base", model_kwargs=model_kwargs)
50
+ compressor = CrossEncoderReranker(model=reranker, top_n=K)
51
 
52
  llm_model_translation = {
53
  "LLaMA 3": "llama3-70b-8192",
 
68
  }
69
 
70
 
71
+ xml_system = """You're a helpful AI assistant. Given a user prompt and some related sources, fulfill all the requirements \
72
+ of the prompt and provide citations. If a chunk of the generated text does not use any of the sources (for example, \
73
+ introductions or general text), don't put a citation for that chunk and just leave "citations" section empty. Otherwise, \
74
+ list all sources used for that chunk of the text. Remember, don't add inline citations in the text itself in any circumstant.
75
+ Add all citations to the separate citations section. Use explicit new lines in the text to show paragraph splits. For each chunk use this example format:
76
+ <chunk>
77
+ <text>This is a sample text chunk....</text>
78
+ <citations>
79
+ <citation>1</citation>
80
+ <citation>3</citation>
81
+ ...
82
+ </citations>
83
+ </chunk>
84
+ If the prompt asks for a reference section, add it in a chunk without any citations
85
+ Return a citation for every quote across all articles that justify the text. Remember use the following format for your final output:
86
+ <cited_text>
87
+ <chunk>
88
+ <text></text>
89
+ <citations>
90
+ <citation><source_id></source_id></citation>
91
+ ...
92
+ </citations>
93
+ </chunk>
94
+ <chunk>
95
+ <text></text>
96
+ <citations>
97
+ <citation><source_id></source_id></citation>
98
+ ...
99
+ </citations>
100
+ </chunk>
101
+ ...
102
+ </cited_text>
103
+ The entire text should be wrapped in one cited_text. For References section (if asked by prompt), don't add citations.
104
+ For source id, give a valid integer alone without a key.
105
+ Here are the sources:{context}"""
106
+ xml_prompt = ChatPromptTemplate.from_messages([("system", xml_system), ("human", "{input}")])
107
+
108
+
109
+ def format_docs_xml(docs: list[Document]) -> str:
110
+ formatted = []
111
+ for i, doc in enumerate(docs):
112
+ doc_str = f"""\
113
+ <source id=\"{i}\">
114
+ <path>{doc.metadata['source']}</path>
115
+ <article_snippet>{doc.page_content}</article_snippet>
116
+ </source>"""
117
+ formatted.append(doc_str)
118
+ return "\n\n<sources>" + "\n".join(formatted) + "</sources>"
119
+
120
+
121
+ def get_doc_content(docs, id):
122
+ return docs[id].page_content
123
+
124
+
125
+ def remove_citations(text):
126
+ text = re.sub(r"<\d+>", "", text)
127
+ return text
128
+
129
+
130
+ def display_cited_text(data):
131
+ combined_text = ""
132
+ citations = {}
133
+ # Iterate through the cited_text list
134
+ if "cited_text" in data:
135
+ for item in data["cited_text"]:
136
+ if "chunk" in item and len(item["chunk"]) > 0:
137
+ chunk_text = item["chunk"][0].get("text")
138
+ combined_text += chunk_text
139
+ citation_ids = []
140
+ # Process the citations for the chunk
141
+ if len(item["chunk"]) > 1 and item["chunk"][1]["citations"]:
142
+ for c in item["chunk"][1]["citations"]:
143
+ if c and "citation" in c:
144
+ citation = c["citation"]
145
+ if isinstance(citation, dict) and "source_id" in citation:
146
+ citation = citation["source_id"]
147
+ if isinstance(citation, str):
148
+ try:
149
+ citation_ids.append(int(citation))
150
+ except ValueError:
151
+ pass # Handle cases where the string is not a valid integer
152
+ if citation_ids:
153
+ citation_texts = [f"<{cid}>" for cid in citation_ids]
154
+ combined_text += " " + "".join(citation_texts)
155
+ combined_text += "\n\n"
156
+ return combined_text
157
+
158
+
159
+ def get_citations(data, docs):
160
+ # Initialize variables for the combined text and a dictionary for citations
161
+ citations = {}
162
+ # Iterate through the cited_text list
163
+ if data.get("cited_text"):
164
+ for item in data["cited_text"]:
165
+ citation_ids = []
166
+ if "chunk" in item and len(item["chunk"]) > 1 and item["chunk"][1].get("citations"):
167
+ for c in item["chunk"][1]["citations"]:
168
+ if c and "citation" in c:
169
+ citation = c["citation"]
170
+ if isinstance(citation, dict) and "source_id" in citation:
171
+ citation = citation["source_id"]
172
+ if isinstance(citation, str):
173
+ try:
174
+ citation_ids.append(int(citation))
175
+ except ValueError:
176
+ pass # Handle cases where the string is not a valid integer
177
+ # Store unique citations in a dictionary
178
+ for citation_id in citation_ids:
179
+ if citation_id not in citations:
180
+ citations[citation_id] = {
181
+ "source": docs[citation_id].metadata["source"],
182
+ "content": docs[citation_id].page_content,
183
+ }
184
+
185
+ return citations
186
+
187
+
188
+ def citations_to_html(citations):
189
+ if citations:
190
+ # Generate the HTML for the unique citations
191
+ html_content = ""
192
+ for citation_id, citation_info in citations.items():
193
+ html_content += (
194
+ f"<li><strong>Source ID:</strong> {citation_id}<br>"
195
+ f"<strong>Path:</strong> {citation_info['source']}<br>"
196
+ f"<strong>Page Content:</strong> {citation_info['content']}</li>"
197
+ )
198
+ html_content += "</ul></body></html>"
199
+ return html_content
200
+ return ""
201
+
202
+
203
  def load_llm(model: str, api_key: str, temperature: float = 1.0, max_length: int = 2048):
204
  model_name = llm_model_translation.get(model)
205
  llm_class = llm_classes.get(model_name)
 
213
  return llm
214
 
215
 
216
+ def create_db_with_langchain(path: list[str], url_content: dict, query: str):
217
  all_docs = []
218
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
 
219
  if path:
220
  for file in path:
221
  loader = PyMuPDFLoader(file)
 
231
  docs = text_splitter.split_documents([doc])
232
  all_docs.extend(docs)
233
 
234
+ print(f"### Total number of documents before bm25s: {len(all_docs)}")
235
+
236
+ # if the number of docs is too high, we need to reduce it
237
+ num_max_docs = 300
238
+ if len(all_docs) > num_max_docs:
239
+ docs_raw = [doc.page_content for doc in all_docs]
240
+ retriever = bm25s.BM25(corpus=docs_raw)
241
+ retriever.index(bm25s.tokenize(docs_raw))
242
+ results, scores = retriever.retrieve(bm25s.tokenize(query), k=len(docs_raw), sorted=False)
243
+ top_indices = np.argpartition(scores[0], -num_max_docs)[-num_max_docs:]
244
+ all_docs = [all_docs[i] for i in top_indices]
245
+
246
  # print docs
247
  for idx, doc in enumerate(all_docs):
248
  print(f"Doc: {idx} | Length = {len(doc.page_content)}")
249
 
250
  assert len(all_docs) > 0, "No PDFs or scrapped data provided"
251
  db = Chroma.from_documents(all_docs, embedding_function)
252
+ torch.cuda.empty_cache()
253
+ gc.collect()
254
  return db
255
 
256
 
257
+ def pretty_print_docs(docs):
258
+ print(f"\n{'-' * 100}\n".join([f"Document {i+1}:\n\n" + d.page_content for i, d in enumerate(docs)]))
259
+
260
+
261
  def generate_rag(
262
  prompt: str,
263
+ input_role: str,
264
  topic: str,
265
+ context: str,
266
  model: str,
267
  url_content: dict,
268
  path: list[str],
 
275
  if llm is None:
276
  print("Failed to load LLM. Aborting operation.")
277
  return None
 
 
 
278
 
279
+ query = llm_wrapper(input_role, topic, context, model="OpenAI GPT 4o", task_type="rag", temperature=0.7)
280
+ print("### Query: ", query)
281
+ db = create_db_with_langchain(path, url_content, query)
282
+ retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": K, "fetch_k": FETCH_K, "lambda_mult": 0.75})
283
 
284
+ # docs = retriever.get_relevant_documents(query)
285
+ compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
286
+ docs = compression_retriever.invoke(query)
287
+ print(pretty_print_docs(docs))
288
+
289
+ formatted_docs = format_docs_xml(docs)
290
+ rag_chain = RunnablePassthrough.assign(context=lambda _: formatted_docs) | xml_prompt | llm | XMLOutputParser()
291
+ result = rag_chain.invoke({"input": prompt})
292
+ citations = get_citations(result, docs)
293
+ db.delete_collection() # important, othwerwise it will keep the documents in memory
294
+ torch.cuda.empty_cache()
295
+ gc.collect()
296
+ return result, citations
297
 
298
 
299
  def generate_base(
 
302
  llm = load_llm(model, api_key, temperature, max_length)
303
  if llm is None:
304
  print("Failed to load LLM. Aborting operation.")
305
+ return None, None
306
  try:
307
  output = llm.invoke(prompt).content
308
+ output_dict = {"cited_text": [{"chunk": [{"text": output}, {"citations": None}]}]}
309
+ return output_dict, None
310
  except Exception as e:
311
  print(f"An error occurred while running the model: {e}")
312
+ return None, None
313
 
314
 
315
  def generate(
316
  prompt: str,
317
+ input_role: str,
318
  topic: str,
319
+ context: str,
320
  model: str,
321
  url_content: dict,
322
  path: list[str],
 
326
  sys_message="",
327
  ):
328
  if path or url_content:
329
+ return generate_rag(
330
+ prompt, input_role, topic, context, model, url_content, path, temperature, max_length, api_key, sys_message
331
+ )
332
  else:
333
  return generate_base(prompt, topic, model, temperature, max_length, api_key, sys_message)
334
+
335
+
336
+ def llm_wrapper(
337
+ iam=None,
338
+ topic=None,
339
+ context=None,
340
+ temperature=1.0,
341
+ max_length=512,
342
+ api_key="",
343
+ model="OpenAI GPT 4o Mini",
344
+ task_type="internet",
345
+ ):
346
+ llm = load_llm(model, api_key, temperature, max_length)
347
+
348
+ if task_type == "rag":
349
+ system_message_content = """You are an AI assistant tasked with reformulating user inputs to improve retrieval query in a RAG system.
350
+ - Given the original user inputs, construct query to be more specific, detailed, and likely to retrieve relevant information.
351
+ - Generate the query as a complete sentence or question, not just as keywords, to ensure the retrieval process can find detailed and contextually relevant information.
352
+ - You may enhance the query by adding related and relevant terms, but do not introduce new facts, such as dates, numbers, or assumed information, that were not provided in the input.
353
+
354
+ **Inputs:**
355
+ - **User Role**: {iam}
356
+ - **Topic**: {topic}
357
+ - **Context**: {context}
358
+
359
+ **Only return the search query**."""
360
+ elif task_type == "internet":
361
+ system_message_content = """You are an AI assistant tasked with generating an optimized Google search query to help retrieve relevant websites, news, articles, and other sources of information.
362
+ - You may enhance the query by adding related and relevant terms, but do not introduce new facts, such as dates, numbers, or assumed information, that were not provided in the input.
363
+ - The query should be **concise** and include important **keywords** while incorporating **short phrases** or context where it improves the search.
364
+ - Avoid the use of "site:" operators or narrowing search by specific websites.
365
+
366
+ **Inputs:**
367
+ - **User Role**: {iam}
368
+ - **Topic**: {topic}
369
+ - **Context**: {context}
370
+
371
+ **Only return the search query**.
372
+ """
373
+ else:
374
+ raise ValueError("Task type not recognized. Please specify 'rag' or 'internet'.")
375
+
376
+ human_message = HumanMessage(content=system_message_content.format(iam=iam, topic=topic, context=context))
377
+ response = llm.invoke([human_message])
378
+ return response.content.strip('"').strip("'")
app.py CHANGED
@@ -3,41 +3,261 @@ nohup python3 app.py &
3
  export GOOGLE_APPLICATION_CREDENTIALS="gcp_creds.json"
4
  """
5
 
 
6
  import re
 
 
7
  from typing import Dict
8
  from collections import defaultdict
9
  from datetime import date, datetime
10
 
11
- import gradio as gr
12
  import nltk
13
  import torch
14
  import numpy as np
15
- from scipy.special import softmax
16
  import language_tool_python
17
- from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- from utils import remove_special_characters, split_text_allow_complete_sentences_nltk
20
- from google_search import google_search, months, domain_list, build_date
21
- from humanize import humanize_text, device
22
- from ai_generate import generate
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
- print(f"Using device: {device}")
 
 
 
25
 
26
- models = {
27
- "Polygraf AI (Base Model)": AutoModelForSequenceClassification.from_pretrained(
28
- "polygraf-ai/bc-roberta-openai-2sent"
29
- ).to(device),
30
- "Polygraf AI (Advanced Model)": AutoModelForSequenceClassification.from_pretrained(
31
- "polygraf-ai/bc_combined_3sent"
32
- ).to(device),
33
- }
34
- tokenizers = {
35
- "Polygraf AI (Base Model)": AutoTokenizer.from_pretrained("polygraf-ai/bc-roberta-openai-2sent"),
36
- "Polygraf AI (Advanced Model)": AutoTokenizer.from_pretrained("polygraf-ai/bc_combined_3sent"),
37
- }
38
 
39
- # grammar correction tool
40
- tool = language_tool_python.LanguageTool("en-US")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
 
43
  # Function to move model to the appropriate device
@@ -62,7 +282,8 @@ def clean_text(text: str) -> str:
62
  cleaned = re.sub(r"\s+", " ", paragraph).strip()
63
  cleaned = re.sub(r"(?<=\.) ([a-z])", lambda x: x.group(1).upper(), cleaned)
64
  cleaned_paragraphs.append(cleaned)
65
- return "\n".join(cleaned_paragraphs)
 
66
 
67
 
68
  def format_references(text: str) -> str:
@@ -137,6 +358,8 @@ def predict(model, tokenizer, text):
137
  output = model(**tokens)
138
  output_norm = softmax(output.logits.detach().cpu().numpy(), 1)[0]
139
  output_norm = {"HUMAN": output_norm[0], "AI": output_norm[1]}
 
 
140
  return output_norm
141
 
142
 
@@ -196,13 +419,6 @@ ai_check_options = [
196
  ]
197
 
198
 
199
- MC_TOKEN_SIZE = 256
200
- TEXT_MC_MODEL_PATH = "polygraf-ai/mc-model"
201
- MC_LABEL_MAP = ["OpenAI GPT", "Mistral", "CLAUDE", "Gemini", "Grammar Enhancer"]
202
- text_mc_tokenizer = AutoTokenizer.from_pretrained(TEXT_MC_MODEL_PATH)
203
- text_mc_model = AutoModelForSequenceClassification.from_pretrained(TEXT_MC_MODEL_PATH).to(device)
204
-
205
-
206
  def predict_mc(text):
207
  with torch.no_grad():
208
  text_mc_model.eval()
@@ -215,6 +431,8 @@ def predict_mc(text):
215
  ).to(device)
216
  output = text_mc_model(**tokens)
217
  output_norm = softmax(output.logits.detach().cpu().numpy(), 1)[0]
 
 
218
  return output_norm
219
 
220
 
@@ -244,6 +462,7 @@ def predict_mc_scores(input, bc_score):
244
 
245
 
246
  def highlighter_polygraf(text, model="Polygraf AI (Base Model)"):
 
247
  body, references = split_text_from_refs(text)
248
  score, text = detection_polygraf(text=body, model=model)
249
  mc_score = predict_mc_scores(body, score) # mc score
@@ -251,7 +470,8 @@ def highlighter_polygraf(text, model="Polygraf AI (Base Model)"):
251
  return score, text, mc_score
252
 
253
 
254
- def ai_check(text: str, option: str):
 
255
  if option.startswith("Polygraf AI"):
256
  return highlighter_polygraf(text, option)
257
  else:
@@ -259,35 +479,39 @@ def ai_check(text: str, option: str):
259
 
260
 
261
  def generate_prompt(settings: Dict[str, str]) -> str:
 
 
262
  prompt = f"""
263
- I am a {settings['role']}
264
- Write a {settings['article_length']} words (around) {settings['format']} on {settings['topic']}.
 
 
265
  Context:
266
  - {settings['context']}
267
-
 
268
  Style and Tone:
269
  - Writing style: {settings['writing_style']}
270
  - Tone: {settings['tone']}
271
  - Target audience: {settings['user_category']}
272
-
273
  Content:
274
  - Depth: {settings['depth_of_content']}
275
  - Structure: {', '.join(settings['structure'])}
276
-
 
 
277
  Keywords to incorporate:
278
  {', '.join(settings['keywords'])}
279
-
 
280
  Additional requirements:
281
  - Don't start with "Here is a...", start with the requested text directly
282
- - Include {settings['num_examples']} relevant examples or case studies
283
- - Incorporate data or statistics from {', '.join(settings['references'])}
284
  - End with a {settings['conclusion_type']} conclusion
285
- - Add a "References" section in the format "References:" on a new line at the end with at least 3 credible detailed sources, formatted as [1], [2], etc. with each source on their own line
286
- - Do not repeat sources
287
  - Do not make any headline, title bold.
288
-
289
- Ensure proper paragraph breaks for better readability.
290
- Avoid any references to artificial intelligence, language models, or the fact that this is generated by an AI, and do not mention something like here is the article etc.
291
  """
292
  return prompt
293
 
@@ -299,7 +523,7 @@ def regenerate_prompt(settings: Dict[str, str]) -> str:
299
  Edit the given text based on user comments.
300
  User Comments:
301
  - {settings['user_comments']}
302
-
303
  Requirements:
304
  - Don't start with "Here is a...", start with the requested text directly
305
  - The original content should not be changed. Make minor modifications based on user comments above.
@@ -307,7 +531,7 @@ def regenerate_prompt(settings: Dict[str, str]) -> str:
307
  - Do not make any headline, title bold.
308
  Context:
309
  - {settings['context']}
310
-
311
  Ensure proper paragraph breaks for better readability.
312
  Avoid any references to artificial intelligence, language models, or the fact that this is generated by an AI, and do not mention something like here is the article etc.
313
  """
@@ -361,23 +585,29 @@ def generate_article(
361
  prompt = generate_prompt(settings)
362
 
363
  print("Generated Prompt...\n", prompt)
364
- article = generate(
365
  prompt=prompt,
 
366
  topic=topic,
 
367
  model=ai_model,
368
  url_content=url_content,
369
  path=pdf_file_input,
 
370
  temperature=1,
371
  max_length=2048,
372
  api_key=api_key,
373
  sys_message="",
374
  )
375
-
376
- return clean_text(article)
377
 
378
 
379
  def get_history(history):
380
- return history
 
 
 
 
381
 
382
 
383
  def clear_history():
@@ -386,8 +616,8 @@ def clear_history():
386
 
387
 
388
  def humanize(
389
- text: str,
390
  model: str,
 
391
  temperature: float = 1.2,
392
  repetition_penalty: float = 1,
393
  top_k: int = 50,
@@ -395,21 +625,35 @@ def humanize(
395
  history=None,
396
  ) -> str:
397
  print("Humanizing text...")
398
- body, references = split_text_from_refs(text)
399
- result = humanize_text(
400
- text=body,
 
 
401
  model_name=model,
402
  temperature=temperature,
403
  repetition_penalty=repetition_penalty,
404
  top_k=top_k,
405
  length_penalty=length_penalty,
406
  )
407
- result = result + references
408
- corrected_text = format_and_correct_language_check(result)
409
-
410
  timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
411
- history.append((f"Humanized Text | {timestamp}\nInput: {model}", corrected_text))
412
- return corrected_text, history
 
 
 
 
 
 
 
 
 
 
 
 
413
 
414
 
415
  def update_visibility_api(model: str):
@@ -445,11 +689,6 @@ def update_temperature(model_dropdown):
445
  return gr.update(value=1.0, interactive=True)
446
 
447
 
448
- import uuid
449
- import json
450
- from datetime import datetime
451
- from google.cloud import storage
452
-
453
  # Initialize Google Cloud Storage client
454
  client = storage.Client()
455
  bucket_name = "ai-source-detection"
@@ -460,7 +699,6 @@ def save_to_cloud_storage(
460
  article,
461
  topic,
462
  input_role,
463
- topic_context,
464
  context,
465
  keywords,
466
  article_length,
@@ -493,7 +731,6 @@ def save_to_cloud_storage(
493
  "metadata": {
494
  "topic": topic,
495
  "input_role": input_role,
496
- "topic_context": topic_context,
497
  "context": context,
498
  "keywords": keywords,
499
  "article_length": article_length,
@@ -524,6 +761,31 @@ def save_to_cloud_storage(
524
  return f"Data saved as {file_name} in GCS."
525
 
526
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
527
  def generate_and_format(
528
  input_role,
529
  topic,
@@ -561,7 +823,9 @@ def generate_and_format(
561
  date_from = build_date(year_from, month_from, day_from)
562
  date_to = build_date(year_to, month_to, day_to)
563
  sorted_date = f"date:r:{date_from}:{date_to}"
564
- final_query = topic
 
 
565
  if include_sites:
566
  site_queries = [f"site:{site.strip()}" for site in include_sites.split(",")]
567
  final_query += " " + " OR ".join(site_queries)
@@ -570,10 +834,10 @@ def generate_and_format(
570
  final_query += " " + " ".join(exclude_queries)
571
  print(f"Google Search Query: {final_query}")
572
  url_content = google_search(final_query, sorted_date, domains_to_include)
573
- topic_context = topic + ", " + context
574
- article = generate_article(
575
  input_role,
576
- topic_context,
577
  context,
578
  keywords,
579
  article_length,
@@ -593,13 +857,14 @@ def generate_and_format(
593
  generated_article,
594
  user_comments,
595
  )
596
- if ends_with_references(article) and url_content is not None:
597
- for url in url_content.keys():
598
- article += f"\n{url}"
599
 
600
- reference_formatted = format_references(article)
 
601
  timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
602
- history.append((f"Generated Text | {timestamp}\nInput: {topic}", reference_formatted))
603
 
604
  # Save the article and metadata to Cloud Storage
605
  # We dont save if there is PDF input for privacy reasons
@@ -608,7 +873,6 @@ def generate_and_format(
608
  article,
609
  topic,
610
  input_role,
611
- topic_context,
612
  context,
613
  keywords,
614
  article_length,
@@ -628,415 +892,428 @@ def generate_and_format(
628
  timestamp,
629
  )
630
  print(save_message)
631
-
632
- return reference_formatted, history
633
-
634
-
635
- def create_interface():
636
- with gr.Blocks(
637
- theme=gr.themes.Default(
638
- primary_hue=gr.themes.colors.pink, secondary_hue=gr.themes.colors.yellow, neutral_hue=gr.themes.colors.gray
639
- ),
640
- css="""
641
- .input-highlight-pink block_label {background-color: #008080}
642
- """,
643
- ) as demo:
644
- history = gr.State([])
645
- today = date.today()
646
- # dd/mm/YY
647
- d1 = today.strftime("%d/%B/%Y")
648
- d1 = d1.split("/")
649
- gr.Markdown("# Polygraf AI Content Writer", elem_classes="text-center text-3xl mb-6")
650
-
651
- with gr.Row():
652
- with gr.Column(scale=2):
653
- with gr.Group():
654
- gr.Markdown("## Article Configuration", elem_classes="text-xl mb-4")
655
- input_role = gr.Textbox(label="I am a", placeholder="Enter your role", value="Student")
656
- input_topic = gr.Textbox(
657
- label="Topic",
658
- placeholder="Enter the main topic of your article",
659
- elem_classes="input-highlight-pink",
660
- )
661
- input_context = gr.Textbox(
662
- label="Context",
663
- placeholder="Provide some context for your topic",
664
- elem_classes="input-highlight-pink",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
665
  )
666
- input_keywords = gr.Textbox(
667
- label="Keywords",
668
- placeholder="Enter comma-separated keywords",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
669
  elem_classes="input-highlight-yellow",
670
  )
 
 
 
 
 
 
671
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
672
  with gr.Row():
673
- input_format = gr.Dropdown(
674
- choices=[
675
- "Article",
676
- "Essay",
677
- "Blog post",
678
- "Report",
679
- "Research paper",
680
- "News article",
681
- "White paper",
682
- "Email",
683
- "LinkedIn post",
684
- "X (Twitter) post",
685
- "Instagram Video Content",
686
- "TikTok Video Content",
687
- "Facebook post",
688
- ],
689
- value="Article",
690
- label="Format",
691
- elem_classes="input-highlight-turquoise",
692
  )
693
-
694
- input_length = gr.Slider(
695
- minimum=50,
696
- maximum=5000,
697
- step=50,
698
- value=300,
699
- label="Article Length",
700
- elem_classes="input-highlight-pink",
701
- )
702
-
703
  with gr.Row():
704
- input_writing_style = gr.Dropdown(
705
- choices=[
706
- "Formal",
707
- "Informal",
708
- "Technical",
709
- "Conversational",
710
- "Journalistic",
711
- "Academic",
712
- "Creative",
713
- ],
714
- value="Formal",
715
- label="Writing Style",
716
  elem_classes="input-highlight-yellow",
717
  )
718
- input_tone = gr.Dropdown(
719
- choices=["Friendly", "Professional", "Neutral", "Enthusiastic", "Skeptical", "Humorous"],
720
- value="Professional",
721
- label="Tone",
722
- elem_classes="input-highlight-turquoise",
 
 
 
 
 
 
 
 
723
  )
 
 
724
 
725
- input_user_category = gr.Dropdown(
726
- choices=[
727
- "Students",
728
- "Professionals",
729
- "Researchers",
730
- "General Public",
731
- "Policymakers",
732
- "Entrepreneurs",
733
- ],
734
- value="General Public",
735
- label="Target Audience",
736
- elem_classes="input-highlight-pink",
737
- )
738
- input_depth = gr.Dropdown(
739
- choices=[
740
- "Surface-level overview",
741
- "Moderate analysis",
742
- "In-depth research",
743
- "Comprehensive study",
744
- ],
745
- value="Moderate analysis",
746
- label="Depth of Content",
747
- elem_classes="input-highlight-yellow",
748
- )
749
- input_structure = gr.Dropdown(
750
- choices=[
751
- "Introduction, Body, Conclusion",
752
- "Abstract, Introduction, Methods, Results, Discussion, Conclusion",
753
- "Executive Summary, Problem Statement, Analysis, Recommendations, Conclusion",
754
- "Introduction, Literature Review, Methodology, Findings, Analysis, Conclusion",
755
- "Plain Text",
756
- ],
757
- value="Introduction, Body, Conclusion",
758
- label="Structure",
759
- elem_classes="input-highlight-turquoise",
760
- interactive=True,
761
- )
762
- input_references = gr.Dropdown(
763
- choices=[
764
- "Academic journals",
765
- "Industry reports",
766
- "Government publications",
767
- "News outlets",
768
- "Expert interviews",
769
- "Case studies",
770
- ],
771
- value="News outlets",
772
- label="References",
773
- elem_classes="input-highlight-pink",
774
- )
775
- input_num_examples = gr.Dropdown(
776
- choices=["1-2", "3-4", "5+"],
777
- value="1-2",
778
- label="Number of Examples/Case Studies",
779
- elem_classes="input-highlight-yellow",
780
- )
781
- input_conclusion = gr.Dropdown(
782
- choices=["Summary", "Call to Action", "Future Outlook", "Thought-provoking Question"],
783
- value="Call to Action",
784
- label="Conclusion Type",
785
- elem_classes="input-highlight-turquoise",
786
  )
787
- gr.Markdown("# Search Options", elem_classes="text-center text-3xl mb-6")
788
- google_default = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
789
  with gr.Row():
790
- google_search_check = gr.Checkbox(
791
- label="Enable Internet Search For Recent Sources", value=google_default
792
  )
793
- with gr.Group(visible=google_default) as search_options:
794
- with gr.Row():
795
- include_sites = gr.Textbox(
796
- label="Include Specific Websites",
797
- placeholder="Enter comma-separated keywords",
798
- elem_classes="input-highlight-yellow",
799
- )
800
- with gr.Row():
801
- exclude_sites = gr.Textbox(
802
- label="Exclude Specific Websites",
803
- placeholder="Enter comma-separated keywords",
804
- elem_classes="input-highlight-yellow",
805
- )
806
- with gr.Row():
807
- domains_to_include = gr.Dropdown(
808
- domain_list,
809
- value=domain_list,
810
- multiselect=True,
811
- label="Domains To Include",
812
- )
813
- with gr.Row():
814
- month_from = gr.Dropdown(
815
- choices=months,
816
- label="From Month",
817
- value="January",
818
- interactive=True,
819
- )
820
- day_from = gr.Textbox(label="From Day", value="01")
821
- year_from = gr.Textbox(label="From Year", value="2000")
822
-
823
- with gr.Row():
824
- month_to = gr.Dropdown(
825
- choices=months,
826
- label="To Month",
827
- value=d1[1],
828
- interactive=True,
829
- )
830
- day_to = gr.Textbox(label="To Day", value=d1[0])
831
- year_to = gr.Textbox(label="To Year", value=d1[2])
832
-
833
- gr.Markdown("# Add Optional PDF Files with Information", elem_classes="text-center text-3xl mb-6")
834
- pdf_file_input = gr.File(label="Upload PDF(s)", file_count="multiple", file_types=[".pdf"])
835
  """
836
- # NOTE: HIDE AI MODEL SELECTION
837
- with gr.Group():
838
- gr.Markdown("## AI Model Configuration", elem_classes="text-xl mb-4")
839
- ai_generator = gr.Dropdown(
840
- choices=[
841
- "OpenAI GPT 4",
842
- "OpenAI GPT 4o",
843
- "OpenAI GPT 4o Mini",
844
- "Claude Sonnet 3.5",
845
- "Gemini 1.5 Pro",
846
- "LLaMA 3",
847
- ],
848
- value="OpenAI GPT 4o Mini",
849
- label="AI Model",
850
- elem_classes="input-highlight-pink",
851
- )
852
- input_api = gr.Textbox(label="API Key", visible=False)
853
- ai_generator.change(update_visibility_api, ai_generator, input_api)
854
  """
855
- generate_btn = gr.Button("Generate Article", variant="primary")
856
 
857
- with gr.Column(scale=3):
858
- with gr.Tab("Text Generator"):
859
- output_article = gr.Textbox(label="Generated Article", lines=20)
860
- ai_comments = gr.Textbox(
861
- label="Add comments to help edit generated text", interactive=True, visible=False
862
- )
863
- regenerate_btn = gr.Button("Regenerate Article", variant="primary", visible=False)
864
- ai_detector_dropdown = gr.Radio(
865
- choices=ai_check_options, label="Select AI Detector", value="Polygraf AI"
866
- )
867
- ai_check_btn = gr.Button("AI Check")
868
-
869
- with gr.Accordion("AI Detection Results", open=True):
870
- ai_check_result = gr.Label(label="AI Check Result")
871
- mc_check_result = gr.Label(label="Creator Check Result")
872
- highlighted_text = gr.HTML(label="Sentence Breakdown", visible=False)
873
-
874
- with gr.Accordion("Advanced Humanizer Settings", open=False):
875
- with gr.Row():
876
- model_dropdown = gr.Radio(
877
- choices=["Standard Model", "Advanced Model (Beta)"],
878
- value="Advanced Model (Beta)",
879
- label="Humanizer Model Version",
880
- )
881
- with gr.Row():
882
- temperature_slider = gr.Slider(
883
- minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Temperature"
884
- )
885
- top_k_slider = gr.Slider(minimum=0, maximum=300, step=25, value=40, label="Top k")
886
- with gr.Row():
887
- repetition_penalty_slider = gr.Slider(
888
- minimum=1.0, maximum=2.0, step=0.1, value=1, label="Repetition Penalty"
889
- )
890
- length_penalty_slider = gr.Slider(
891
- minimum=0.0, maximum=2.0, step=0.1, value=1.0, label="Length Penalty"
892
- )
893
-
894
- humanize_btn = gr.Button("Humanize")
895
- # humanized_output = gr.Markdown(label="Humanized Article", value="\n\n\n\n", render=True)
896
- # copy_to_input_btn = gr.Button("Copy to Input for AI Check")
897
-
898
- with gr.Tab("History"):
899
- history_chat = gr.Chatbot(label="Generation History", height=1000)
900
- clear_history_btn = gr.Button("Clear History")
901
- clear_history_btn.click(clear_history, outputs=[history, history_chat])
902
- """
903
- # NOTE: REMOVED REFRESH BUTTON
904
- refresh_button = gr.Button("Refresh History")
905
- refresh_button.click(get_history, outputs=history_chat)
906
- """
907
-
908
- def regenerate_visible(text):
909
- if text:
910
- return gr.update(visible=True)
911
- else:
912
- return gr.update(visible=False)
913
-
914
- def highlight_visible(text):
915
- if text.startswith("Polygraf"):
916
- return gr.update(visible=True)
917
- else:
918
- return gr.update(visible=False)
919
-
920
- def search_visible(toggle):
921
- if toggle:
922
- return gr.update(visible=True)
923
- else:
924
- return gr.update(visible=False)
925
-
926
- google_search_check.change(search_visible, inputs=google_search_check, outputs=search_options)
927
- ai_detector_dropdown.change(highlight_visible, inputs=ai_detector_dropdown, outputs=highlighted_text)
928
- output_article.change(regenerate_visible, inputs=output_article, outputs=ai_comments)
929
- ai_comments.change(regenerate_visible, inputs=output_article, outputs=regenerate_btn)
930
- ai_check_btn.click(highlight_visible, inputs=ai_detector_dropdown, outputs=highlighted_text)
931
-
932
- # Update the default structure based on the selected format
933
- # e.g. "Plain Text" for certain formats
934
- input_format.change(fn=update_structure, inputs=input_format, outputs=input_structure)
935
- model_dropdown.change(fn=update_temperature, inputs=model_dropdown, outputs=temperature_slider)
936
-
937
- generate_btn.click(
938
- fn=generate_and_format,
939
- inputs=[
940
- input_role,
941
- input_topic,
942
- input_context,
943
- input_keywords,
944
- input_length,
945
- input_format,
946
- input_writing_style,
947
- input_tone,
948
- input_user_category,
949
- input_depth,
950
- input_structure,
951
- input_references,
952
- input_num_examples,
953
- input_conclusion,
954
- # ai_generator,
955
- # input_api,
956
- google_search_check,
957
- year_from,
958
- month_from,
959
- day_from,
960
- year_to,
961
- month_to,
962
- day_to,
963
- domains_to_include,
964
- include_sites,
965
- exclude_sites,
966
- pdf_file_input,
967
- history,
968
- ],
969
- outputs=[output_article, history],
970
- )
971
 
972
- regenerate_btn.click(
973
- fn=generate_and_format,
974
- inputs=[
975
- input_role,
976
- input_topic,
977
- input_context,
978
- input_keywords,
979
- input_length,
980
- input_format,
981
- input_writing_style,
982
- input_tone,
983
- input_user_category,
984
- input_depth,
985
- input_structure,
986
- input_references,
987
- input_num_examples,
988
- input_conclusion,
989
- # ai_generator,
990
- # input_api,
991
- google_search_check,
992
- year_from,
993
- month_from,
994
- day_from,
995
- year_to,
996
- month_to,
997
- day_to,
998
- domains_to_include,
999
- pdf_file_input,
1000
- history,
1001
- output_article,
1002
- include_sites,
1003
- exclude_sites,
1004
- ai_comments,
1005
- ],
1006
- outputs=[output_article, history],
1007
- )
1008
 
1009
- ai_check_btn.click(
1010
- fn=ai_check,
1011
- inputs=[output_article, ai_detector_dropdown],
1012
- outputs=[ai_check_result, highlighted_text, mc_check_result],
1013
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1014
 
1015
- humanize_btn.click(
1016
- fn=humanize,
1017
- inputs=[
1018
- output_article,
1019
- model_dropdown,
1020
- temperature_slider,
1021
- repetition_penalty_slider,
1022
- top_k_slider,
1023
- length_penalty_slider,
1024
- history,
1025
- ],
1026
- outputs=[output_article, history],
1027
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1028
 
1029
- generate_btn.click(get_history, inputs=[history], outputs=[history_chat])
1030
- regenerate_btn.click(get_history, inputs=[history], outputs=[history_chat])
1031
- humanize_btn.click(get_history, inputs=[history], outputs=[history_chat])
1032
 
1033
- return demo
1034
 
1035
 
1036
  if __name__ == "__main__":
1037
- demo = create_interface()
1038
- demo.queue(
1039
- max_size=2,
1040
- default_concurrency_limit=2,
1041
- ).launch(server_name="0.0.0.0", share=True, server_port=7890)
1042
- # demo.launch(server_name="0.0.0.0")
 
3
  export GOOGLE_APPLICATION_CREDENTIALS="gcp_creds.json"
4
  """
5
 
6
+ import gc
7
  import re
8
+ import uuid
9
+ import json
10
  from typing import Dict
11
  from collections import defaultdict
12
  from datetime import date, datetime
13
 
 
14
  import nltk
15
  import torch
16
  import numpy as np
17
+ import gradio as gr
18
  import language_tool_python
19
+ from scipy.special import softmax
20
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
21
+ from google.cloud import storage
22
+
23
+ if gr.NO_RELOAD:
24
+ from humanize import humanize_text, device
25
+ from utils import remove_special_characters, split_text_allow_complete_sentences_nltk
26
+ from google_search import google_search, months, domain_list, build_date
27
+ from ai_generate import generate, citations_to_html, remove_citations, display_cited_text, llm_wrapper
28
+
29
+ nltk.download("punkt_tab")
30
+
31
+ print(f"Using device: {device}")
32
+ print("Loading AI detection models...")
33
+ models = {
34
+ "Polygraf AI (Base Model)": AutoModelForSequenceClassification.from_pretrained(
35
+ "polygraf-ai/bc-roberta-openai-2sent"
36
+ ).to(device),
37
+ "Polygraf AI (Advanced Model)": AutoModelForSequenceClassification.from_pretrained(
38
+ "polygraf-ai/bc_combined_3sent"
39
+ ).to(device),
40
+ }
41
+ tokenizers = {
42
+ "Polygraf AI (Base Model)": AutoTokenizer.from_pretrained("polygraf-ai/bc-roberta-openai-2sent"),
43
+ "Polygraf AI (Advanced Model)": AutoTokenizer.from_pretrained("polygraf-ai/bc_combined_3sent"),
44
+ }
45
+
46
+ # grammar correction tool
47
+ tool = language_tool_python.LanguageTool("en-US")
48
+
49
+ # source detection model
50
+ MC_TOKEN_SIZE = 256
51
+ TEXT_MC_MODEL_PATH = "polygraf-ai/mc-model"
52
+ MC_LABEL_MAP = ["OpenAI GPT", "Mistral", "CLAUDE", "Gemini", "Grammar Enhancer"]
53
+ text_mc_tokenizer = AutoTokenizer.from_pretrained(TEXT_MC_MODEL_PATH)
54
+ print("Loading Source detection model...")
55
+ text_mc_model = AutoModelForSequenceClassification.from_pretrained(TEXT_MC_MODEL_PATH).to(device)
56
+
57
+
58
+ def generate_cited_html(cited_text, citations: dict):
59
+ cited_text = cited_text.replace("\n", "<br>")
60
+ html_code = """
61
+ <style>
62
+ .reference-container {
63
+ position: relative;
64
+ display: inline-block;
65
+ }
66
+ .reference-btn {
67
+ display: inline-block;
68
+ width: 20px; /* Reduced width */
69
+ height: 20px; /* Reduced height */
70
+ border-radius: 50%;
71
+ background-color: #e33a89; /* Pink color for the button */
72
+ color: white;
73
+ text-align: center;
74
+ line-height: 20px; /* Adjusted line-height */
75
+ cursor: pointer;
76
+ font-weight: bold;
77
+ margin-right: 5px;
78
+ transition: background-color 0.3s ease, transform 0.3s ease;
79
+ }
80
+ .reference-btn:hover {
81
+ background-color: #ff69b4; /* Lighter pink on hover */
82
+ transform: scale(1.1); /* Slightly enlarge on hover */
83
+ }
84
+ .reference-popup {
85
+ display: none;
86
+ position: absolute;
87
+ z-index: 1;
88
+ top: 100%;
89
+ background-color: #f9f9f9;
90
+ border: 1px solid #ddd;
91
+ padding: 15px;
92
+ border-radius: 4px;
93
+ box-shadow: 0 2px 5px rgba(0,0,0,0.2);
94
+ width: calc(min(90vw, 400px));
95
+ max-height: calc(min(80vh, 300px));
96
+ overflow-y: auto;
97
+ }
98
+ .reference-popup .close-btn {
99
+ float: right;
100
+ cursor: pointer;
101
+ font-weight: bold;
102
+ color: white;
103
+ font-size: 16px;
104
+ padding: 0;
105
+ width: 20px;
106
+ height: 20px;
107
+ text-align: center;
108
+ line-height: 20px;
109
+ background-color: #ff4c4c;
110
+ border-radius: 2px;
111
+ transition: transform 0.3s ease, background-color 0.3s ease;
112
+ }
113
+ .reference-popup .close-btn:hover {
114
+ transform: scale(1.2);
115
+ background-color: #ff3333;
116
+ }
117
+ input[type="radio"] {
118
+ position: absolute;
119
+ opacity: 0;
120
+ pointer-events: none;
121
+ }
122
+ input[type="radio"]:checked + .reference-popup {
123
+ display: block;
124
+ }
125
+
126
+ /* Additional styling for distinct sections */
127
+ .reference-popup strong {
128
+ font-weight: bold;
129
+ color: #333;
130
+ display: block;
131
+ margin-bottom: 5px;
132
+ }
133
+ .reference-popup p {
134
+ margin: 0 0 10px 0;
135
+ padding: 0;
136
+ }
137
+ .reference-popup .source {
138
+ margin-bottom: 10px;
139
+ font-size: 14px;
140
+ font-weight: bold;
141
+ color: #1e90ff;
142
+ }
143
+ .reference-popup .content {
144
+ margin-bottom: 10px;
145
+ font-size: 13px;
146
+ color: #555;
147
+ }
148
 
149
+ @media (prefers-color-scheme: dark) {
150
+ .reference-btn {
151
+ background-color: #1e90ff;
152
+ }
153
+ .reference-popup {
154
+ background-color: #2c2c2c;
155
+ border-color: #444;
156
+ color: #f1f1f1;
157
+ }
158
+ .reference-popup .close-btn {
159
+ background-color: #ff4c4c;
160
+ }
161
+ .reference-popup .close-btn:hover {
162
+ background-color: #ff3333;
163
+ }
164
+ .reference-popup strong {
165
+ color: #ddd;
166
+ }
167
+ .reference-popup .source {
168
+ color: #1e90ff;
169
+ }
170
+ .reference-popup .content {
171
+ color: #bbb;
172
+ }
173
+ }
174
+ </style>
175
+ <script>
176
+ document.addEventListener('click', (event) => {
177
+ const containers = document.querySelectorAll('.reference-container');
178
+ containers.forEach(container => {
179
+ const rect = container.getBoundingClientRect();
180
+ const popup = container.querySelector('.reference-popup');
181
+
182
+ // Reset alignment
183
+ popup.style.left = '';
184
+ popup.style.right = '';
185
+
186
+ const popupWidth = popup.offsetWidth;
187
+ const viewportWidth = window.innerWidth;
188
+
189
+ // If the popup would go off the right edge
190
+ if (rect.right + popupWidth > viewportWidth) {
191
+ popup.style.right = '0'; // Align popup to the right
192
+ }
193
+ // If the popup would go off the left edge
194
+ else if (rect.left - popupWidth < 0) {
195
+ popup.style.left = '0'; // Align popup to the left
196
+ }
197
+ // Otherwise center it
198
+ else {
199
+ popup.style.left = '50%';
200
+ popup.style.transform = 'translateX(-50%)'; // Center the popup
201
+ }
202
+ });
203
+ });
204
+
205
+ function closeReferencePanes() {
206
+ document.querySelectorAll('input[name="reference"]').forEach((input) => {
207
+ input.checked = false;
208
+ });
209
+ }
210
+ </script>
211
+ <div style="height: 600px; overflow-y: auto; overflow-x: auto;">
212
+ """
213
+
214
+ # Function to replace each citation with a reference button
215
+ citation_count = 0 # To track unique instances of each citation
216
 
217
+ def replace_citations(match):
218
+ nonlocal citation_count
219
+ citation_id = match.group(1) # Extract citation number from the match
220
+ ref_data = citations.get(int(citation_id))
221
 
222
+ # If reference data is not found, return the original text
223
+ if not ref_data:
224
+ return match.group(0)
 
 
 
 
 
 
 
 
 
225
 
226
+ # Getting PDF file from gradio path
227
+ if "/var/tmp/gradio/" in ref_data["source"]:
228
+ ref_data["source"] = ref_data["source"].split("/")[-1]
229
+
230
+ # remove new line artifacts from scraping / parsing
231
+ ref_data["content"] = ref_data["content"].replace("\n", " ")
232
+
233
+ # Check if source is a URL, make it clickable if so
234
+ if ref_data["source"].startswith("http"):
235
+ source_html = f'<a href="{ref_data["source"]}" target="_blank" class="source">{ref_data["source"]}</a>'
236
+ else:
237
+ source_html = f'<span class="source">{ref_data["source"]}</span>'
238
+
239
+ # Unique id for each reference button and popup
240
+ unique_id = f"{citation_id}-{citation_count}"
241
+ citation_count += 1
242
+
243
+ # HTML code for the reference button and popup with formatted content
244
+ button_html = f"""
245
+ <span class="reference-container">
246
+ <label for="ref-toggle-{unique_id}" class="reference-btn" onclick="closeReferencePanes(); document.getElementById('ref-toggle-{unique_id}').checked = true;">{int(citation_id)+1}</label>
247
+ <input type="radio" id="ref-toggle-{unique_id}" name="reference" />
248
+ <span class="reference-popup">
249
+ <span class="close-btn" onclick="document.getElementById('ref-toggle-{unique_id}').checked = false;">&times;</span>
250
+ <strong>Source:</strong> {source_html}
251
+ <strong>Content:</strong> <p class="content">{ref_data["content"]}</p>
252
+ </span>
253
+ </span>
254
+ """
255
+ return button_html
256
+
257
+ # Replace inline citations in the text with the generated HTML
258
+ html_code += re.sub(r"<(\d+)>", replace_citations, cited_text)
259
+ html_code += "</div>"
260
+ return html_code
261
 
262
 
263
  # Function to move model to the appropriate device
 
282
  cleaned = re.sub(r"\s+", " ", paragraph).strip()
283
  cleaned = re.sub(r"(?<=\.) ([a-z])", lambda x: x.group(1).upper(), cleaned)
284
  cleaned_paragraphs.append(cleaned)
285
+ cleaned_paragraphs = [item for item in cleaned_paragraphs if item.strip()]
286
+ return "\n\n".join(cleaned_paragraphs)
287
 
288
 
289
  def format_references(text: str) -> str:
 
358
  output = model(**tokens)
359
  output_norm = softmax(output.logits.detach().cpu().numpy(), 1)[0]
360
  output_norm = {"HUMAN": output_norm[0], "AI": output_norm[1]}
361
+ torch.cuda.empty_cache()
362
+ gc.collect()
363
  return output_norm
364
 
365
 
 
419
  ]
420
 
421
 
 
 
 
 
 
 
 
422
  def predict_mc(text):
423
  with torch.no_grad():
424
  text_mc_model.eval()
 
431
  ).to(device)
432
  output = text_mc_model(**tokens)
433
  output_norm = softmax(output.logits.detach().cpu().numpy(), 1)[0]
434
+ torch.cuda.empty_cache()
435
+ gc.collect()
436
  return output_norm
437
 
438
 
 
462
 
463
 
464
  def highlighter_polygraf(text, model="Polygraf AI (Base Model)"):
465
+ text = remove_citations(text)
466
  body, references = split_text_from_refs(text)
467
  score, text = detection_polygraf(text=body, model=model)
468
  mc_score = predict_mc_scores(body, score) # mc score
 
470
  return score, text, mc_score
471
 
472
 
473
+ def ai_check(history: list, option: str):
474
+ text = history[-1][1]
475
  if option.startswith("Polygraf AI"):
476
  return highlighter_polygraf(text, option)
477
  else:
 
479
 
480
 
481
  def generate_prompt(settings: Dict[str, str]) -> str:
482
+ settings["keywords"] = [item for item in settings["keywords"] if item.strip()]
483
+ # - Add a "References" section in the format "References:" on a new line after the requested text, formatted as [1], [2], etc. with each source on their own line
484
  prompt = f"""
485
+ Write a {settings['article_length']} words (around) {settings['format']} on {settings['topic']}.\n
486
+ """
487
+ if settings["context"]:
488
+ prompt += f"""
489
  Context:
490
  - {settings['context']}
491
+ """
492
+ prompt += f"""
493
  Style and Tone:
494
  - Writing style: {settings['writing_style']}
495
  - Tone: {settings['tone']}
496
  - Target audience: {settings['user_category']}
497
+
498
  Content:
499
  - Depth: {settings['depth_of_content']}
500
  - Structure: {', '.join(settings['structure'])}
501
+ """
502
+ if len(settings["keywords"]) > 0:
503
+ prompt += f"""
504
  Keywords to incorporate:
505
  {', '.join(settings['keywords'])}
506
+ """
507
+ prompt += f"""
508
  Additional requirements:
509
  - Don't start with "Here is a...", start with the requested text directly
 
 
510
  - End with a {settings['conclusion_type']} conclusion
 
 
511
  - Do not make any headline, title bold.
512
+ - Ensure proper paragraph breaks for better readability.
513
+ - Avoid any references to artificial intelligence, language models, or the fact that this is generated by an AI, and do not mention something like here is the article etc.
514
+ - Adhere to any format structure provided to the system if any.
515
  """
516
  return prompt
517
 
 
523
  Edit the given text based on user comments.
524
  User Comments:
525
  - {settings['user_comments']}
526
+
527
  Requirements:
528
  - Don't start with "Here is a...", start with the requested text directly
529
  - The original content should not be changed. Make minor modifications based on user comments above.
 
531
  - Do not make any headline, title bold.
532
  Context:
533
  - {settings['context']}
534
+
535
  Ensure proper paragraph breaks for better readability.
536
  Avoid any references to artificial intelligence, language models, or the fact that this is generated by an AI, and do not mention something like here is the article etc.
537
  """
 
585
  prompt = generate_prompt(settings)
586
 
587
  print("Generated Prompt...\n", prompt)
588
+ article, citations = generate(
589
  prompt=prompt,
590
+ input_role=input_role,
591
  topic=topic,
592
+ context=context,
593
  model=ai_model,
594
  url_content=url_content,
595
  path=pdf_file_input,
596
+ # path=["./final_report.pdf"], # TODO: reset
597
  temperature=1,
598
  max_length=2048,
599
  api_key=api_key,
600
  sys_message="",
601
  )
602
+ return article, citations
 
603
 
604
 
605
  def get_history(history):
606
+ # return history
607
+ history_formatted = []
608
+ for entry in history:
609
+ history_formatted.append((entry[0], entry[1]))
610
+ return history_formatted
611
 
612
 
613
  def clear_history():
 
616
 
617
 
618
  def humanize(
 
619
  model: str,
620
+ cited_text: str,
621
  temperature: float = 1.2,
622
  repetition_penalty: float = 1,
623
  top_k: int = 50,
 
625
  history=None,
626
  ) -> str:
627
  print("Humanizing text...")
628
+ # body, references = split_text_from_refs(text)
629
+ cited_text = history[-1][1]
630
+ citations = history[-1][2]
631
+ article = humanize_text(
632
+ text=cited_text,
633
  model_name=model,
634
  temperature=temperature,
635
  repetition_penalty=repetition_penalty,
636
  top_k=top_k,
637
  length_penalty=length_penalty,
638
  )
639
+ # result = result + references
640
+ # corrected_text = format_and_correct_language_check(result)
641
+ article = clean_text(article)
642
  timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
643
+ history.append((f"Humanized Text | {timestamp}\nInput: {model}", article, citations))
644
+ latest_humanizer_data = {
645
+ "original text": cited_text,
646
+ "humanized text": article,
647
+ "citations": citations, # can remove saving citations
648
+ "metadata": {
649
+ "temperature": temperature,
650
+ "repetition_penalty": repetition_penalty,
651
+ "top_k": top_k,
652
+ "length_penalty": length_penalty,
653
+ },
654
+ "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
655
+ }
656
+ return generate_cited_html(article, citations), history, latest_humanizer_data
657
 
658
 
659
  def update_visibility_api(model: str):
 
689
  return gr.update(value=1.0, interactive=True)
690
 
691
 
 
 
 
 
 
692
  # Initialize Google Cloud Storage client
693
  client = storage.Client()
694
  bucket_name = "ai-source-detection"
 
699
  article,
700
  topic,
701
  input_role,
 
702
  context,
703
  keywords,
704
  article_length,
 
731
  "metadata": {
732
  "topic": topic,
733
  "input_role": input_role,
 
734
  "context": context,
735
  "keywords": keywords,
736
  "article_length": article_length,
 
761
  return f"Data saved as {file_name} in GCS."
762
 
763
 
764
+ def save_humanizer_feedback_to_cloud_storage(data, humanizer_feedback):
765
+ """Save generated article and metadata to Google Cloud Storage within a specific folder."""
766
+ if data:
767
+ try:
768
+ data["user_feedback"] = humanizer_feedback
769
+ # Create a unique filename
770
+ file_id = str(uuid.uuid4())
771
+
772
+ # Define the file path and name in the bucket
773
+ folder_path = "ai-writer/humanizer-feedback/"
774
+ file_name = f"{folder_path}{data['timestamp'].replace(' ', '_').replace(':', '-')}_{file_id}.json"
775
+
776
+ # Convert data to JSON string
777
+ json_data = json.dumps(data)
778
+
779
+ # Create a blob and upload to GCS
780
+ blob = bucket.blob(file_name)
781
+ blob.upload_from_string(json_data, content_type="application/json")
782
+ gr.Info("Successfully reported. Thank you for the feedback!")
783
+ except Exception:
784
+ gr.Warning("Report not saved.")
785
+ else:
786
+ gr.Warning("Nothing humanized to save yet!")
787
+
788
+
789
  def generate_and_format(
790
  input_role,
791
  topic,
 
823
  date_from = build_date(year_from, month_from, day_from)
824
  date_to = build_date(year_to, month_to, day_to)
825
  sorted_date = f"date:r:{date_from}:{date_to}"
826
+ final_query = llm_wrapper(
827
+ input_role, topic, context, model="OpenAI GPT 4o", task_type="internet", temperature=0.7
828
+ )
829
  if include_sites:
830
  site_queries = [f"site:{site.strip()}" for site in include_sites.split(",")]
831
  final_query += " " + " OR ".join(site_queries)
 
834
  final_query += " " + " ".join(exclude_queries)
835
  print(f"Google Search Query: {final_query}")
836
  url_content = google_search(final_query, sorted_date, domains_to_include)
837
+ # topic_context = topic + ", " + context
838
+ article, citations = generate_article(
839
  input_role,
840
+ topic,
841
  context,
842
  keywords,
843
  article_length,
 
857
  generated_article,
858
  user_comments,
859
  )
860
+ # if ends_with_references(article) and url_content is not None:
861
+ # for url in url_content.keys():
862
+ # article += f"\n{url}"
863
 
864
+ article = clean_text(display_cited_text(article))
865
+ # reference_formatted = format_references(article)
866
  timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
867
+ history.append((f"Generated Text | {timestamp}\nInput: {topic}", article, citations))
868
 
869
  # Save the article and metadata to Cloud Storage
870
  # We dont save if there is PDF input for privacy reasons
 
873
  article,
874
  topic,
875
  input_role,
 
876
  context,
877
  keywords,
878
  article_length,
 
892
  timestamp,
893
  )
894
  print(save_message)
895
+ return generate_cited_html(article, citations), history
896
+
897
+
898
+ # def create_interface():
899
+ with gr.Blocks(
900
+ theme=gr.themes.Default(
901
+ primary_hue=gr.themes.colors.pink, secondary_hue=gr.themes.colors.yellow, neutral_hue=gr.themes.colors.gray
902
+ ),
903
+ css="""
904
+ .input-highlight-pink block_label {background-color: #008080}
905
+ """,
906
+ ) as demo:
907
+ history = gr.State([])
908
+ latest_humanizer_data = gr.State()
909
+ today = date.today()
910
+ # dd/mm/YY
911
+ d1 = today.strftime("%d/%B/%Y")
912
+ d1 = d1.split("/")
913
+ gr.Markdown("# Polygraf AI Content Writer", elem_classes="text-center text-3xl mb-6")
914
+
915
+ with gr.Row():
916
+ with gr.Column(scale=1):
917
+ with gr.Group():
918
+ gr.Markdown("## Article Configuration", elem_classes="text-xl mb-4")
919
+ input_role = gr.Textbox(label="I am a", placeholder="Enter your role", value="Student")
920
+ input_topic = gr.Textbox(
921
+ label="Topic",
922
+ placeholder="Enter the main topic of your article",
923
+ elem_classes="input-highlight-pink",
924
+ )
925
+ input_context = gr.Textbox(
926
+ label="Context",
927
+ placeholder="Provide some context for your topic",
928
+ elem_classes="input-highlight-pink",
929
+ )
930
+ input_keywords = gr.Textbox(
931
+ label="Keywords",
932
+ placeholder="Enter comma-separated keywords",
933
+ elem_classes="input-highlight-yellow",
934
+ )
935
+
936
+ with gr.Row():
937
+ input_format = gr.Dropdown(
938
+ choices=[
939
+ "Article",
940
+ "Essay",
941
+ "Blog post",
942
+ "Report",
943
+ "Research paper",
944
+ "News article",
945
+ "White paper",
946
+ "Email",
947
+ "LinkedIn post",
948
+ "X (Twitter) post",
949
+ "Instagram Video Content",
950
+ "TikTok Video Content",
951
+ "Facebook post",
952
+ ],
953
+ value="Article",
954
+ label="Format",
955
+ elem_classes="input-highlight-turquoise",
956
  )
957
+
958
+ input_length = gr.Slider(
959
+ minimum=50,
960
+ maximum=5000,
961
+ step=50,
962
+ value=300,
963
+ label="Article Length",
964
+ elem_classes="input-highlight-pink",
965
+ )
966
+
967
+ with gr.Row():
968
+ input_writing_style = gr.Dropdown(
969
+ choices=[
970
+ "Formal",
971
+ "Informal",
972
+ "Technical",
973
+ "Conversational",
974
+ "Journalistic",
975
+ "Academic",
976
+ "Creative",
977
+ ],
978
+ value="Formal",
979
+ label="Writing Style",
980
  elem_classes="input-highlight-yellow",
981
  )
982
+ input_tone = gr.Dropdown(
983
+ choices=["Friendly", "Professional", "Neutral", "Enthusiastic", "Skeptical", "Humorous"],
984
+ value="Professional",
985
+ label="Tone",
986
+ elem_classes="input-highlight-turquoise",
987
+ )
988
 
989
+ input_user_category = gr.Dropdown(
990
+ choices=[
991
+ "Students",
992
+ "Professionals",
993
+ "Researchers",
994
+ "General Public",
995
+ "Policymakers",
996
+ "Entrepreneurs",
997
+ ],
998
+ value="General Public",
999
+ label="Target Audience",
1000
+ elem_classes="input-highlight-pink",
1001
+ )
1002
+ input_depth = gr.Dropdown(
1003
+ choices=[
1004
+ "Surface-level overview",
1005
+ "Moderate analysis",
1006
+ "In-depth research",
1007
+ "Comprehensive study",
1008
+ ],
1009
+ value="Moderate analysis",
1010
+ label="Depth of Content",
1011
+ elem_classes="input-highlight-yellow",
1012
+ )
1013
+ input_structure = gr.Dropdown(
1014
+ choices=[
1015
+ "Introduction, Body, Conclusion",
1016
+ "Abstract, Introduction, Methods, Results, Discussion, Conclusion",
1017
+ "Executive Summary, Problem Statement, Analysis, Recommendations, Conclusion",
1018
+ "Introduction, Literature Review, Methodology, Findings, Analysis, Conclusion",
1019
+ "Plain Text",
1020
+ ],
1021
+ value="Introduction, Body, Conclusion",
1022
+ label="Structure",
1023
+ elem_classes="input-highlight-turquoise",
1024
+ interactive=True,
1025
+ )
1026
+ input_references = gr.Dropdown(
1027
+ choices=[
1028
+ "Academic journals",
1029
+ "Industry reports",
1030
+ "Government publications",
1031
+ "News outlets",
1032
+ "Expert interviews",
1033
+ "Case studies",
1034
+ ],
1035
+ value="News outlets",
1036
+ label="References",
1037
+ elem_classes="input-highlight-pink",
1038
+ )
1039
+ input_num_examples = gr.Dropdown(
1040
+ choices=["1-2", "3-4", "5+"],
1041
+ value="1-2",
1042
+ label="Number of Examples/Case Studies",
1043
+ elem_classes="input-highlight-yellow",
1044
+ )
1045
+ input_conclusion = gr.Dropdown(
1046
+ choices=["Summary", "Call to Action", "Future Outlook", "Thought-provoking Question"],
1047
+ value="Call to Action",
1048
+ label="Conclusion Type",
1049
+ elem_classes="input-highlight-turquoise",
1050
+ )
1051
+ gr.Markdown("# Search Options", elem_classes="text-center text-3xl mb-6")
1052
+ google_default = False
1053
+ with gr.Row():
1054
+ google_search_check = gr.Checkbox(
1055
+ label="Enable Internet Search For Recent Sources", value=google_default
1056
+ )
1057
+ with gr.Group(visible=google_default) as search_options:
1058
  with gr.Row():
1059
+ include_sites = gr.Textbox(
1060
+ label="Include Specific Websites",
1061
+ placeholder="Enter comma-separated keywords",
1062
+ elem_classes="input-highlight-yellow",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1063
  )
 
 
 
 
 
 
 
 
 
 
1064
  with gr.Row():
1065
+ exclude_sites = gr.Textbox(
1066
+ label="Exclude Specific Websites",
1067
+ placeholder="Enter comma-separated keywords",
 
 
 
 
 
 
 
 
 
1068
  elem_classes="input-highlight-yellow",
1069
  )
1070
+ with gr.Row():
1071
+ domains_to_include = gr.Dropdown(
1072
+ domain_list,
1073
+ value=domain_list,
1074
+ multiselect=True,
1075
+ label="Domains To Include",
1076
+ )
1077
+ with gr.Row():
1078
+ month_from = gr.Dropdown(
1079
+ choices=months,
1080
+ label="From Month",
1081
+ value="January",
1082
+ interactive=True,
1083
  )
1084
+ day_from = gr.Textbox(label="From Day", value="01")
1085
+ year_from = gr.Textbox(label="From Year", value="2000")
1086
 
1087
+ with gr.Row():
1088
+ month_to = gr.Dropdown(
1089
+ choices=months,
1090
+ label="To Month",
1091
+ value=d1[1],
1092
+ interactive=True,
1093
+ )
1094
+ day_to = gr.Textbox(label="To Day", value=d1[0])
1095
+ year_to = gr.Textbox(label="To Year", value=d1[2])
1096
+
1097
+ gr.Markdown("# Add Optional PDF Files with Information", elem_classes="text-center text-3xl mb-6")
1098
+ pdf_file_input = gr.File(label="Upload PDF(s)", file_count="multiple", file_types=[".pdf"])
1099
+ """
1100
+ # NOTE: HIDE AI MODEL SELECTION
1101
+ with gr.Group():
1102
+ gr.Markdown("## AI Model Configuration", elem_classes="text-xl mb-4")
1103
+ ai_generator = gr.Dropdown(
1104
+ choices=[
1105
+ "OpenAI GPT 4",
1106
+ "OpenAI GPT 4o",
1107
+ "OpenAI GPT 4o Mini",
1108
+ "Claude Sonnet 3.5",
1109
+ "Gemini 1.5 Pro",
1110
+ "LLaMA 3",
1111
+ ],
1112
+ value="OpenAI GPT 4o Mini",
1113
+ label="AI Model",
1114
+ elem_classes="input-highlight-pink",
1115
+ )
1116
+ input_api = gr.Textbox(label="API Key", visible=False)
1117
+ ai_generator.change(update_visibility_api, ai_generator, input_api)
1118
+ """
1119
+ generate_btn = gr.Button("Generate Article", variant="primary")
1120
+
1121
+ with gr.Column(scale=2):
1122
+ with gr.Tab("Text Generator"):
1123
+ output_article = gr.HTML(
1124
+ value="""<div style="height: 600px;"></div>""",
1125
+ label="Generated Article",
1126
+ )
1127
+ with gr.Accordion("Regenerate Article", open=False):
1128
+ ai_comments = gr.Textbox(
1129
+ label="Add comments to help edit generated text", interactive=True, visible=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1130
  )
1131
+ regenerate_btn = gr.Button("Regenerate Article", variant="primary", visible=True)
1132
+
1133
+ ai_detector_dropdown = gr.Dropdown(
1134
+ choices=ai_check_options, label="Select AI Detector", value="Polygraf AI (Base Model)"
1135
+ )
1136
+ ai_check_btn = gr.Button("AI Check")
1137
+
1138
+ with gr.Accordion("AI Detection Results", open=True):
1139
+ ai_check_result = gr.Label(label="AI Check Result")
1140
+ mc_check_result = gr.Label(label="Creator Check Result")
1141
+ highlighted_text = gr.HTML(label="Sentence Breakdown", visible=False)
1142
+
1143
+ with gr.Accordion("Advanced Humanizer Settings", open=False):
1144
+ with gr.Row():
1145
+ model_dropdown = gr.Radio(
1146
+ choices=["Standard Model", "Advanced Model (Beta)"],
1147
+ value="Advanced Model (Beta)",
1148
+ label="Humanizer Model Version",
1149
+ )
1150
  with gr.Row():
1151
+ temperature_slider = gr.Slider(
1152
+ minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Temperature"
1153
  )
1154
+ top_k_slider = gr.Slider(minimum=0, maximum=300, step=25, value=40, label="Top k")
1155
+ with gr.Row():
1156
+ repetition_penalty_slider = gr.Slider(
1157
+ minimum=1.0, maximum=2.0, step=0.1, value=1, label="Repetition Penalty"
1158
+ )
1159
+ length_penalty_slider = gr.Slider(
1160
+ minimum=0.0, maximum=2.0, step=0.1, value=1.0, label="Length Penalty"
1161
+ )
1162
+
1163
+ humanize_btn = gr.Button("Humanize")
1164
+ with gr.Row(equal_height=False):
1165
+ with gr.Column():
1166
+ humanizer_feedback = gr.Textbox(label="Add optional feedback on humanizer")
1167
+ with gr.Column():
1168
+ report_humanized_btn = gr.Button("Report Humanized Text", variant="primary", visible=True)
1169
+ # humanized_output = gr.Markdown(label="Humanized Article", value="\n\n\n\n", render=True)
1170
+ # copy_to_input_btn = gr.Button("Copy to Input for AI Check")
1171
+
1172
+ with gr.Tab("History"):
1173
+ history_chat = gr.Chatbot(label="Generation History", height=1000)
1174
+ clear_history_btn = gr.Button("Clear History")
1175
+ clear_history_btn.click(clear_history, outputs=[history, history_chat])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1176
  """
1177
+ # NOTE: REMOVED REFRESH BUTTON
1178
+ refresh_button = gr.Button("Refresh History")
1179
+ refresh_button.click(get_history, outputs=history_chat)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1180
  """
 
1181
 
1182
+ def regenerate_visible(text):
1183
+ if text:
1184
+ return gr.update(visible=True)
1185
+ else:
1186
+ return gr.update(visible=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1187
 
1188
+ def highlight_visible(text):
1189
+ if text.startswith("Polygraf"):
1190
+ return gr.update(visible=True)
1191
+ else:
1192
+ return gr.update(visible=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1193
 
1194
+ def search_visible(toggle):
1195
+ if toggle:
1196
+ return gr.update(visible=True)
1197
+ else:
1198
+ return gr.update(visible=False)
1199
+
1200
+ google_search_check.change(search_visible, inputs=google_search_check, outputs=search_options)
1201
+ # ai_detector_dropdown.change(highlight_visible, inputs=ai_detector_dropdown, outputs=highlighted_text)
1202
+ # output_article.change(regenerate_visible, inputs=output_article, outputs=ai_comments)
1203
+ # ai_comments.change(regenerate_visible, inputs=output_article, outputs=regenerate_btn)
1204
+ ai_check_btn.click(highlight_visible, inputs=ai_detector_dropdown, outputs=highlighted_text)
1205
+
1206
+ # Update the default structure based on the selected format
1207
+ # e.g. "Plain Text" for certain formats
1208
+ input_format.change(fn=update_structure, inputs=input_format, outputs=input_structure)
1209
+ model_dropdown.change(fn=update_temperature, inputs=model_dropdown, outputs=temperature_slider)
1210
+ report_humanized_btn.click(
1211
+ save_humanizer_feedback_to_cloud_storage, inputs=[latest_humanizer_data, humanizer_feedback]
1212
+ )
1213
 
1214
+ generate_btn.click(
1215
+ fn=generate_and_format,
1216
+ inputs=[
1217
+ input_role,
1218
+ input_topic,
1219
+ input_context,
1220
+ input_keywords,
1221
+ input_length,
1222
+ input_format,
1223
+ input_writing_style,
1224
+ input_tone,
1225
+ input_user_category,
1226
+ input_depth,
1227
+ input_structure,
1228
+ input_references,
1229
+ input_num_examples,
1230
+ input_conclusion,
1231
+ # ai_generator,
1232
+ # input_api,
1233
+ google_search_check,
1234
+ year_from,
1235
+ month_from,
1236
+ day_from,
1237
+ year_to,
1238
+ month_to,
1239
+ day_to,
1240
+ domains_to_include,
1241
+ include_sites,
1242
+ exclude_sites,
1243
+ pdf_file_input,
1244
+ history,
1245
+ ],
1246
+ outputs=[output_article, history],
1247
+ )
1248
+
1249
+ regenerate_btn.click(
1250
+ fn=generate_and_format,
1251
+ inputs=[
1252
+ input_role,
1253
+ input_topic,
1254
+ input_context,
1255
+ input_keywords,
1256
+ input_length,
1257
+ input_format,
1258
+ input_writing_style,
1259
+ input_tone,
1260
+ input_user_category,
1261
+ input_depth,
1262
+ input_structure,
1263
+ input_references,
1264
+ input_num_examples,
1265
+ input_conclusion,
1266
+ # ai_generator,
1267
+ # input_api,
1268
+ google_search_check,
1269
+ year_from,
1270
+ month_from,
1271
+ day_from,
1272
+ year_to,
1273
+ month_to,
1274
+ day_to,
1275
+ domains_to_include,
1276
+ pdf_file_input,
1277
+ history,
1278
+ output_article,
1279
+ include_sites,
1280
+ exclude_sites,
1281
+ ai_comments,
1282
+ ],
1283
+ outputs=[output_article, history],
1284
+ )
1285
+
1286
+ ai_check_btn.click(
1287
+ fn=ai_check,
1288
+ inputs=[history, ai_detector_dropdown],
1289
+ outputs=[ai_check_result, highlighted_text, mc_check_result],
1290
+ )
1291
+
1292
+ humanize_btn.click(
1293
+ fn=humanize,
1294
+ inputs=[
1295
+ model_dropdown,
1296
+ output_article,
1297
+ temperature_slider,
1298
+ repetition_penalty_slider,
1299
+ top_k_slider,
1300
+ length_penalty_slider,
1301
+ history,
1302
+ ],
1303
+ outputs=[output_article, history, latest_humanizer_data],
1304
+ )
1305
 
1306
+ generate_btn.click(get_history, inputs=[history], outputs=[history_chat])
1307
+ regenerate_btn.click(get_history, inputs=[history], outputs=[history_chat])
1308
+ humanize_btn.click(get_history, inputs=[history], outputs=[history_chat])
1309
 
1310
+ # return demo
1311
 
1312
 
1313
  if __name__ == "__main__":
1314
+ # demo = create_interface()
1315
+ # demo.queue(
1316
+ # max_size=2,
1317
+ # default_concurrency_limit=2,
1318
+ # ).launch(server_name="0.0.0.0", share=True, server_port=7890)
1319
+ demo.launch(server_name="0.0.0.0")
humanize.py CHANGED
@@ -3,8 +3,9 @@ import torch
3
  import nltk
4
  from nltk import sent_tokenize
5
  import gradio as gr
6
- from peft import PeftModel
7
  from transformers import T5ForConditionalGeneration, T5Tokenizer
 
 
8
 
9
  nltk.download("punkt")
10
 
@@ -49,7 +50,34 @@ FastLanguageModel.for_inference(dec_only_model) # native 2x faster inference
49
  print(f"Loaded model: {dec_only}, Num. params: {dec_only_model.num_parameters()}")
50
 
51
 
52
- def humanize_batch_seq2seq(model, tokenizer, sentences, temperature, repetition_penalty, top_k, length_penalty):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  inputs = ["Please paraphrase this sentence: " + sentence for sentence in sentences]
54
  inputs = tokenizer(inputs, return_tensors="pt", padding=True, truncation=True).to(model.device)
55
  outputs = model.generate(
@@ -65,7 +93,15 @@ def humanize_batch_seq2seq(model, tokenizer, sentences, temperature, repetition_
65
  return answers
66
 
67
 
68
- def humanize_batch_decoder_only(model, tokenizer, sentences, temperature, repetition_penalty, top_k, length_penalty):
 
 
 
 
 
 
 
 
69
  pre_prompt = "As a humanizer model, your task is to rewrite the following sentence to make it more human-like. Return only the paraphrased sentence. \n\n"
70
  # Construct the messages_batch using the tokenized sentences
71
  messages_batch = [{"from": "human", "value": f"{pre_prompt}{sentence}"} for sentence in sentences]
@@ -73,7 +109,12 @@ def humanize_batch_decoder_only(model, tokenizer, sentences, temperature, repeti
73
  tokenizer = get_chat_template(
74
  tokenizer,
75
  chat_template="phi-3",
76
- mapping={"role": "from", "content": "value", "user": "human", "assistant": "gpt"}, # ShareGPT style
 
 
 
 
 
77
  )
78
 
79
  # Enable native 2x faster inference
@@ -130,9 +171,11 @@ def humanize_text(
130
  Paragraphs are stored as a number of sentences per paragraph.
131
  """
132
  progress(0, desc="Starting to Humanize")
133
-
134
  # Map model names to their respective processing functions
135
- model_map = {"Standard Model": humanize_batch_seq2seq, "Advanced Model (Beta)": humanize_batch_decoder_only}
 
 
 
136
  assert model_name in model_map, f"Invalid model name: {model_name}"
137
  process_function = model_map[model_name]
138
 
@@ -140,7 +183,10 @@ def humanize_text(
140
  paragraphs = text.split("\n")
141
  all_sentences = []
142
  sentences_per_paragraph = []
 
143
  for paragraph in paragraphs:
 
 
144
  sentences = sent_tokenize(paragraph)
145
  sentences_per_paragraph.append(len(sentences))
146
  all_sentences.extend(sentences)
@@ -156,8 +202,8 @@ def humanize_text(
156
 
157
  # Call the selected processing function
158
  paraphrased_batch = process_function(
159
- seq2seq_model if model_name == "Standard Model" else dec_only_model,
160
- seq2seq_tokenizer if model_name == "Standard Model" else dec_only_tokenizer,
161
  batch_sentences,
162
  temperature,
163
  repetition_penalty,
@@ -188,6 +234,8 @@ def humanize_text(
188
  humanized_paragraph = " ".join(paraphrased_sentences[sentence_index : sentence_index + num_sentences])
189
  humanized_paragraphs.append(humanized_paragraph)
190
  sentence_index += num_sentences
191
-
192
- humanized_text = "\n".join(humanized_paragraphs)
 
 
193
  return humanized_text
 
3
  import nltk
4
  from nltk import sent_tokenize
5
  import gradio as gr
 
6
  from transformers import T5ForConditionalGeneration, T5Tokenizer
7
+ import language_tool_python
8
+ import re
9
 
10
  nltk.download("punkt")
11
 
 
50
  print(f"Loaded model: {dec_only}, Num. params: {dec_only_model.num_parameters()}")
51
 
52
 
53
+ # grammar correction tool
54
+ tool = language_tool_python.LanguageTool("en-US")
55
+
56
+
57
+ def format_and_correct_language_check(text: str) -> str:
58
+ return tool.correct(text)
59
+
60
+
61
+ def extract_citations(text):
62
+ citations = re.findall(r"<(\d+)>", text)
63
+ return [int(citation) for citation in citations]
64
+
65
+
66
+ def remove_citations(text):
67
+ text = re.sub(r"<\d+>", "", text)
68
+ text = re.sub(r"[\d+]", "", text)
69
+ return text
70
+
71
+
72
+ def humanize_batch_seq2seq(
73
+ model,
74
+ tokenizer,
75
+ sentences,
76
+ temperature,
77
+ repetition_penalty,
78
+ top_k,
79
+ length_penalty,
80
+ ):
81
  inputs = ["Please paraphrase this sentence: " + sentence for sentence in sentences]
82
  inputs = tokenizer(inputs, return_tensors="pt", padding=True, truncation=True).to(model.device)
83
  outputs = model.generate(
 
93
  return answers
94
 
95
 
96
+ def humanize_batch_decoder_only(
97
+ model,
98
+ tokenizer,
99
+ sentences,
100
+ temperature,
101
+ repetition_penalty,
102
+ top_k,
103
+ length_penalty,
104
+ ):
105
  pre_prompt = "As a humanizer model, your task is to rewrite the following sentence to make it more human-like. Return only the paraphrased sentence. \n\n"
106
  # Construct the messages_batch using the tokenized sentences
107
  messages_batch = [{"from": "human", "value": f"{pre_prompt}{sentence}"} for sentence in sentences]
 
109
  tokenizer = get_chat_template(
110
  tokenizer,
111
  chat_template="phi-3",
112
+ mapping={
113
+ "role": "from",
114
+ "content": "value",
115
+ "user": "human",
116
+ "assistant": "gpt",
117
+ }, # ShareGPT style
118
  )
119
 
120
  # Enable native 2x faster inference
 
171
  Paragraphs are stored as a number of sentences per paragraph.
172
  """
173
  progress(0, desc="Starting to Humanize")
 
174
  # Map model names to their respective processing functions
175
+ model_map = {
176
+ "Standard Model": humanize_batch_seq2seq,
177
+ "Advanced Model (Beta)": humanize_batch_decoder_only,
178
+ }
179
  assert model_name in model_map, f"Invalid model name: {model_name}"
180
  process_function = model_map[model_name]
181
 
 
183
  paragraphs = text.split("\n")
184
  all_sentences = []
185
  sentences_per_paragraph = []
186
+ citations_per_paragraph = []
187
  for paragraph in paragraphs:
188
+ citations_per_paragraph.append(extract_citations(paragraph))
189
+ paragraph = remove_citations(paragraph)
190
  sentences = sent_tokenize(paragraph)
191
  sentences_per_paragraph.append(len(sentences))
192
  all_sentences.extend(sentences)
 
202
 
203
  # Call the selected processing function
204
  paraphrased_batch = process_function(
205
+ (seq2seq_model if model_name == "Standard Model" else dec_only_model),
206
+ (seq2seq_tokenizer if model_name == "Standard Model" else dec_only_tokenizer),
207
  batch_sentences,
208
  temperature,
209
  repetition_penalty,
 
234
  humanized_paragraph = " ".join(paraphrased_sentences[sentence_index : sentence_index + num_sentences])
235
  humanized_paragraphs.append(humanized_paragraph)
236
  sentence_index += num_sentences
237
+ for i, paragraph in enumerate(humanized_paragraphs):
238
+ citation_texts = [f"<{cid}>" for cid in citations_per_paragraph[i]]
239
+ humanized_paragraphs[i] = paragraph + " " + "".join(citation_texts)
240
+ humanized_text = "\n\n".join(humanized_paragraphs)
241
  return humanized_text
requirements.txt CHANGED
@@ -24,4 +24,5 @@ langchain-google-genai
24
  langchain-anthropic
25
  langchain-openai
26
  vertexai
27
- html2text
 
 
24
  langchain-anthropic
25
  langchain-openai
26
  vertexai
27
+ html2text
28
+ bm25s