File size: 5,020 Bytes
5dcba42
 
ee6d38b
 
 
 
 
5dcba42
 
ee6d38b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5dcba42
 
 
 
 
 
 
 
 
 
 
 
 
ee6d38b
 
 
 
 
 
 
5dcba42
 
 
 
 
 
 
 
 
 
 
 
ee6d38b
5dcba42
ee6d38b
e4a0928
ee6d38b
 
 
 
 
 
e4a0928
 
ee6d38b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69383c7
 
e4a0928
 
 
 
 
 
 
 
ee6d38b
 
 
 
 
 
 
 
 
 
 
 
e4a0928
5dcba42
ee6d38b
 
5dcba42
 
 
 
 
 
ee6d38b
 
 
 
5dcba42
ee6d38b
e4a0928
ee6d38b
5dcba42
ee6d38b
 
 
 
 
 
 
 
 
5dcba42
 
 
 
ee6d38b
5dcba42
ee6d38b
 
5dcba42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ee6d38b
 
 
 
5dcba42
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import base64
import html
import io
import json
import os
import re
from functools import cache

from litellm import completion
from pydantic import BaseModel

try:
    from dotenv import load_dotenv
except ImportError:
    # python-dotenv is optional; without it, any settings normally kept in a
    # .env file must already be present in the real process environment.
    pass
else:
    # Only the missing-package case is tolerated: errors raised while loading
    # the .env file itself (and KeyboardInterrupt/SystemExit) now propagate
    # instead of being silently swallowed by a bare ``except:``.
    load_dotenv()

# Sampling parameters for text generation.
# NOTE(review): nothing in this section passes this dict to `completion`;
# confirm whether it is consumed elsewhere.
generation_config = dict(
    # Softmax temperature: higher values make sampling more random.
    temperature=0.9,
    # Nucleus sampling: sample from the smallest set of tokens whose cumulative
    # probability reaches top_p (1 means no cut-off).
    top_p=1,
    # Restrict sampling to the k most likely tokens.
    # NOTE(review): k=1 keeps only the single most likely token, which would
    # make temperature/top_p irrelevant -- confirm this is intended.
    top_k=1,
    # Upper bound on the number of generated tokens.
    max_output_tokens=2048,
)


class TextEdits(BaseModel):
    # One suggested edit to the reviewed text, as produced by the model.
    # Documented with `#` comments on purpose: in pydantic a class docstring
    # becomes the JSON-schema "description", and this model is part of the
    # schema passed as `response_format` in review_text.
    # The field `type` is part of that schema, so it cannot be renamed even
    # though it shadows the builtin.
    term: str  # text span to edit; apply_review locates it case-insensitively
    start_char: int  # presumably the start offset of `term`; apply_review only sorts by it
    end_char: int  # presumably the end offset of `term`; not read in this section
    type: str  # edit category, rendered as `kind` in the HTML templates
    fix: str  # replacement text; an empty string is rendered as a deletion
    reason: str  # the model's explanation, shown in the summary table

class SuggestedEdits(BaseModel):
    # Top-level structured-output schema passed as `response_format` in
    # review_text; callers consume the parsed JSON's "edits" list.
    edits: list[TextEdits]


# Content-filter threshold per harm category, in the shape of Gemini's
# `safety_settings` entries.
# NOTE(review): not passed to `completion` in this section; confirm whether it
# is used elsewhere.
safety_settings = [
    {"category": category, "threshold": threshold}
    for category, threshold in (
        ("HARM_CATEGORY_HARASSMENT", "BLOCK_NONE"),
        ("HARM_CATEGORY_HATE_SPEECH", "BLOCK_NONE"),
        ("HARM_CATEGORY_SEXUALLY_EXPLICIT", "BLOCK_NONE"),
        ("HARM_CATEGORY_DANGEROUS_CONTENT", "BLOCK_ONLY_HIGH"),
    )
]

# Models offered for review. Each entry has:
#   name          - display label, also the lookup key of models_dict
#   model         - model identifier passed to litellm's `completion`
#   image_support - whether the model is meant to accept images (not read in
#                   this section; presumably used by the UI -- confirm)
gemini_models = [
    dict(name=label, model=model_id, image_support=accepts_images)
    for label, model_id, accepts_images in (
        ("Gemini 2.0 Flash", "gemini/gemini-2.0-flash", True),
        ("Gemini 1.5 Pro", "gemini/gemini-1.5-pro", False),
    )
]

# Display name -> model entry, used by process_text and process_image.
models_dict = {entry["name"]: entry for entry in gemini_models}


@cache
def get_file(relative_path: str) -> str:
    """Return the text of a file located relative to this module's directory.

    Results are memoised per ``relative_path`` with ``functools.cache``, so a
    file is read from disk at most once per process; later changes to the file
    on disk are not picked up.

    Args:
        relative_path: path relative to the directory containing this module.

    Returns:
        The whole file decoded as UTF-8.

    Raises:
        OSError: if the file cannot be opened (e.g. FileNotFoundError).
        UnicodeDecodeError: if the file is not valid UTF-8.
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))
    full_path = os.path.join(base_dir, relative_path)
    # Decode explicitly as UTF-8: the platform default encoding (e.g. cp1252
    # on Windows) would mis-decode or reject non-ASCII template content.
    with open(full_path, encoding="utf-8") as handle:
        return handle.read()


def html_title(title: str) -> str:
    """Wrap ``title`` in an ``<h1>`` element.

    The title is inserted verbatim (not HTML-escaped); callers in this module
    only pass fixed strings.
    """
    template = "<h1>{}</h1>"
    return template.format(title)


def apply_review(text: str, review: list[dict]) -> str:
    """Render ``text`` as HTML with each suggested edit highlighted inline.

    Edits are processed in ascending ``start_char`` order. For each edit,
    the first case-insensitive occurrence of its ``term`` is located, starting
    at the end of the previously applied edit. The occurrence is replaced by
    the ``templates/correction.html`` snippet when ``fix`` is non-empty, or
    by ``templates/deletion.html`` when ``fix`` is missing or empty. Edits
    whose term is empty, absent, or overlaps an earlier edit are skipped.

    Everything taken from ``text`` and from the edit dicts is HTML-escaped
    before insertion, so user input such as ``<b>`` is displayed literally.

    Args:
        text: the original text that was reviewed.
        review: edit dicts with at least ``term``, ``start_char`` and
            ``type``; ``fix`` is optional.

    Returns:
        The annotated text wrapped in a ``<pre>`` element.

    Note:
        When an applied edit has no ``fix`` key, ``fix = ""`` is written back
        into the caller's dict (long-standing behaviour that later consumers
        of the same dicts may rely on).
    """
    pieces: list[str] = []
    last_end = 0
    for entity in sorted(review, key=lambda item: item["start_char"]):
        term = entity["term"]
        if not term:
            # An empty pattern matches at every position and would inject
            # empty markup; there is nothing to highlight.
            continue
        # re.escape: the term is literal text, not a regex (a term such as
        # "(" would otherwise be an invalid pattern, and "." or "*" would
        # match the wrong span). Matching the original text with IGNORECASE,
        # instead of searching a lower-cased copy, keeps offsets valid even
        # when lower() would change the string's length.
        match = re.compile(re.escape(term), re.IGNORECASE).search(text, last_end)
        if match is None:
            continue
        start, end = match.span()
        pieces.append(html.escape(text[last_end:start]))
        fix = entity.setdefault("fix", "")
        term_html = html.escape(text[start:end])
        kind_html = html.escape(str(entity["type"]))
        if fix:
            pieces.append(
                get_file("templates/correction.html").format(
                    term=term_html, fix=html.escape(str(fix)), kind=kind_html
                )
            )
        else:
            pieces.append(
                get_file("templates/deletion.html").format(
                    term=term_html, kind=kind_html
                )
            )
        last_end = end
    pieces.append(html.escape(text[last_end:]))
    output = "".join(pieces)
    return f"<pre style='white-space: pre-wrap;'>{output}</pre>"


def review_table_summary(review: list[dict]) -> str:
    """Build an HTML table listing each edit's term, fix, type and reason.

    Args:
        review: edit dicts with at least ``term`` and ``type``; ``fix`` and
            ``reason`` are optional.

    Returns:
        A ``<table>`` with a header row and one row per edit. Cell values
        are HTML-escaped. A missing ``fix`` is shown as an empty cell, which
        matches how a missing fix is treated as a deletion. A missing
        ``reason`` is shown as ``-``.
    """
    header = "<table><tr><th>Term</th><th>Fix</th><th>Type</th><th>Reason</th></tr>"
    rows = []
    for entity in review:
        cells = (
            entity["term"],
            entity.get("fix", ""),
            entity["type"],
            entity.get("reason", "-"),
        )
        row_body = "".join(f"<td>{html.escape(str(value))}</td>" for value in cells)
        rows.append(f"<tr>{row_body}</tr>")
    return header + "".join(rows) + "</table>"


def review_text(model: str, text: str) -> list[dict]:
    """Ask ``model`` to review ``text`` and return its suggested edits.

    The prompt is ``templates/prompt_v1.txt`` formatted with ``text=text``.
    Structured output is requested through ``response_format=SuggestedEdits``,
    and the ``edits`` list of the JSON answer is returned as plain dicts.

    Args:
        model: model identifier passed straight to litellm's ``completion``.
        text: the text to review.

    Returns:
        The ``edits`` list parsed from the model's JSON answer.

    Raises:
        ValueError: if the completion call fails; the underlying error is
            chained as ``__cause__``.
        OSError: if the prompt template cannot be read.
    """
    # Build the prompt outside the try: a missing or malformed template is a
    # programming/deployment error and must not be reported as "offensive
    # content" to the user.
    prompt = get_file("templates/prompt_v1.txt").format(text=text)
    try:
        response = completion(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            response_format=SuggestedEdits,
        )
    except Exception as e:
        print(e)
        raise ValueError(
            "Error while getting answer from the model, make sure the content isn't offensive or dangerous."
        ) from e
    return json.loads(response.choices[0].message.content)["edits"]


def process_text(model: str, text: str) -> str:
    """Review ``text`` with the model whose display name is ``model``.

    Returns an HTML report. If no edits are suggested, the report is a single
    celebratory heading. Otherwise it is the annotated text followed by a
    table explaining each edit.
    """
    model_id = models_dict[model]["model"]
    suggestions = review_text(model_id, text)
    if not len(suggestions):
        return html_title("No issues found in the text 🎉🎉🎉")
    sections = (
        html_title("Reviewed text"),
        apply_review(text, suggestions),
        html_title("Explanation"),
        review_table_summary(suggestions),
    )
    return "".join(sections)

def image_to_base64_string(img):
    """Encode an image as a base64 JPEG string (without a ``data:`` prefix).

    Assumes ``img`` is a PIL image (it uses ``.mode``, ``.convert`` and
    ``.save``); confirm against callers. JPEG cannot store alpha channels or
    palettes, so images in other modes (e.g. ``RGBA``/``P`` from PNG uploads)
    are converted to RGB first. Saving them directly raises ``OSError``.
    ``convert`` returns a copy, so the caller's image is left untouched.
    """
    if img.mode not in ("RGB", "L"):
        img = img.convert("RGB")
    buffer = io.BytesIO()
    img.save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")

def process_image(model: str, image) -> str:
    """Send ``image`` plus the image-review prompt to a model and return its answer.

    Args:
        model: display name of the model, i.e. a key of ``models_dict``.
        image: image to review; assumed to be a PIL image, since it is
            JPEG-encoded by image_to_base64_string (confirm against callers).

    Returns:
        The raw text content of the model's first choice (not parsed).

    Raises:
        KeyError: if ``model`` is not a known display name.
        ValueError: if the completion call fails; the underlying error is
            chained as ``__cause__``.
    """
    prompt = get_file("templates/prompt_image_v1.txt")
    # Resolve the model and encode the image before the try, so unknown
    # models and image-encoding failures are not reported as content-policy
    # problems.
    model_id = models_dict[model]["model"]
    data_url = "data:image/jpeg;base64," + image_to_base64_string(image)
    try:
        response = completion(
            model=model_id,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {
                            "type": "image_url",
                            "image_url": data_url,
                        },
                    ],
                }
            ],
        )
    except Exception as e:
        # litellm maps provider failures onto OpenAI-style exception classes,
        # which are not ValueError subclasses; catching only ValueError meant
        # this friendly message was effectively never produced.
        print(e)
        message = f"Error while getting answer from the model, make sure the content isn't offensive or dangerous. Please try again or change the prompt. {e}"
        raise ValueError(message) from e
    return response.choices[0].message.content