Write4All / app.py
Jaume's picture
pre text wrap
f0002f1
raw
history blame
5.74 kB
import re
import os
import gradio as gr
import json
from functools import cache
import google.generativeai as genai
try:
from dotenv import load_dotenv
load_dotenv()
except:
pass
generation_config = {
"temperature": 0.9, # Temperature of the sampling distribution
"top_p": 1, # Probability of sampling from the top p tokens
"top_k": 1, # Number of top tokens to sample from
"max_output_tokens": 2048,
}
safety_settings = [
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_ONLY_HIGH"},
]
genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
text_model = genai.GenerativeModel(
model_name="gemini-1.0-pro",
generation_config=generation_config,
safety_settings=safety_settings,
)
vision_model = genai.GenerativeModel(
"gemini-pro-vision",
generation_config=generation_config,
safety_settings=safety_settings,
)
@cache
def get_file(path: str) -> str:
with open(path) as f:
return f.read()
def fix_json(json_str: str) -> str:
template = get_file("templates/prompt_json_fix.txt")
prompt = template.format(json=json_str)
response = text_model.generate_content(prompt).text
return response.split("```json")[1].split("```")[0]
def get_json_content(response: str) -> dict:
print(response)
if "```json" not in response:
return []
raw_json = response.split("```json")[1].split("```")[0]
try:
return json.loads(raw_json)
except json.JSONDecodeError as e:
print(e)
new_json = fix_json(raw_json)
print(new_json)
return json.loads(new_json)
def review_text(text: str) -> list[dict]:
template = get_file("templates/prompt_v1.txt")
try:
response = text_model.generate_content(template.format(text=text)).text
except ValueError as e:
print(e)
raise ValueError(
f"Error while getting answer from the model, make sure the content isn't offensive or dangerous."
)
return get_json_content(response)
def review_image(image) -> list[dict]:
prompt = get_file("templates/prompt_image_v1.txt")
try:
response = vision_model.generate_content([prompt, image]).text
except ValueError as e:
print(e)
message = "Error while getting answer from the model, make sure the content isn't offensive or dangerous. Please try again or change the prompt."
gr.Error(message)
raise ValueError(message)
return response
def html_title(title: str) -> str:
return f"<h1>{title}</h1>"
def apply_review(text: str, review: list[dict]) -> str:
output = ""
review = sorted(review, key=lambda x: x["start_char"])
last_end = 0
for entity in review:
starts = [
m.start() + last_end
for m in re.finditer(entity["term"].lower(), text[last_end:].lower())
]
if len(starts) > 0:
start = starts[0]
end = start + len(entity["term"])
output += text[last_end:start]
output += get_file("templates/correction.html").format(
term=text[start:end], fix=entity["fix"], kind=entity["type"]
)
last_end = end
output += text[last_end:]
return f"<pre style='white-space: pre-wrap;'>{output}</pre>"
def review_table_summary(review: list[dict]) -> str:
table = "<table><tr><th>Term</th><th>Fix</th><th>Type</th><th>Reason</th></tr>"
for entity in review:
table += f"<tr><td>{entity['term']}</td><td>{entity['fix']}</td><td>{entity['type']}</td><td>{entity.get('reason', '-')}</td></tr>"
table += "</table>"
return table
def format_entities(text: str, review: list[dict]) -> list[dict]:
entities = []
for entity in review:
# Find all occurrences of the term in the text
starts = [m.start() for m in re.finditer(entity["term"], text)]
if len(starts) > 0:
entities.append(
{
"term": entity["term"],
"start": starts[0],
"end": starts[0] + len(entity["term"]),
"entity": entity["type"],
"fix": entity["fix"],
}
)
else:
print(f"Term '{entity['term']}' not found in the text: '{text}'")
return entities
def process_text(text):
review = review_text(text)
if len(review) == 0:
return html_title("No issues found in the text 🎉🎉🎉")
return (
html_title("Reviewed text")
+ apply_review(text, review)
+ html_title("Explanation")
+ review_table_summary(review)
)
def process_image(image):
print(image)
return review_image(image)
text_ui = gr.Interface(
fn=process_text,
inputs=["text"],
outputs=[gr.HTML(label="Revision")],
examples=[
"The whitelist is incomplete.",
"There's not enough manpower to deliver the project",
"This has never happened in the history of mankind!",
"El hombre desciende del mono.",
"Els homes són animals",
],
)
image_ui = gr.Interface(
fn=process_image,
inputs=gr.Image(sources=["upload", "clipboard"], type="pil"),
outputs=["markdown"],
examples=["static/images/CEOs.png", "static/images/meat_grid.png"],
)
with gr.Blocks() as demo:
gr.Markdown(get_file("static/intro.md"))
gr.TabbedInterface([text_ui, image_ui], ["Check texts", "Check images"])
if __name__ == "__main__":
demo.launch()