File size: 7,480 Bytes
b99458b
 
 
8103be7
b99458b
 
 
8103be7
 
b99458b
8103be7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b17b8f7
 
8103be7
b17b8f7
 
 
834fda3
8103be7
 
 
 
 
 
 
 
5b2681b
 
 
8103be7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c37b0d6
 
1d15344
8103be7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13e3b2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8103be7
13e3b2f
 
 
 
 
 
 
8103be7
 
 
 
 
 
 
 
 
 
 
 
1162de7
8103be7
 
 
e9ce3cc
 
8103be7
 
7d46eac
8103be7
 
 
 
 
 
 
 
 
 
 
 
54739a3
7d46eac
 
8103be7
 
 
 
 
 
 
 
7d46eac
8103be7
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import os
import re
import time
import json
from itertools import cycle

import torch
import gradio as gr
from urllib.parse import unquote 
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList

from data import extract_leaves, split_document, handle_broken_output, clean_json_text, sync_empty_fields
from examples import examples as input_examples
from nuextract_logging import log_event


# Token limit for the raw input text (checked in gradio_interface_function) and also
# the truncation length applied to the whole prompt in predict_chunk.
MAX_INPUT_SIZE = 10_000
# Cap on tokens generated per chunk (passed as max_new_tokens in predict_chunk).
MAX_NEW_TOKENS = 4_000
# Chunk size handed to sliding_window_prediction (and on to split_document).
MAX_WINDOW_SIZE = 4_000

# Description shown at the top of the Gradio Interface. Despite the name this is HTML,
# not Markdown. The limits it advertises ("10k tokens", "over 4k ... sliding window")
# mirror MAX_INPUT_SIZE and MAX_WINDOW_SIZE above; keep them in sync when changing either.
markdown_description = """
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
</head>
<body>
    <img src="https://cdn.prod.website-files.com/638364a4e52e440048a9529c/64188f405afcf42d0b85b926_logo_numind_final.png" alt="NuMind Logo" style="vertical-align: middle;width: 200px; height: 50px;">
    <br>
    <ul>
        <li>NuMind is a startup developing custom information extraction solutions.</li>
        <li>NuExtract is a zero-shot model. See the blog posts for more info (<a href="https://numind.ai/blog/nuextract-a-foundation-model-for-structured-extraction">NuExtract</a>, <a href="https://numind.ai/blog/nuextract-1-5---multilingual-infinite-context-still-small-and-better-than-gpt-4o">NuExtract-v1.5</a>).</li>
        <li>We have started to deploy NuMind Enterprise to customize, serve, and monitor NuExtract privately. If that interests you, let's chat 😊.</li>
        <li><strong>Website</strong>: <a href="https://www.numind.ai/">https://www.numind.ai/</a></li>
    </ul>
    <h1>NuExtract-v1.5</h1>
    <p>NuExtract-v1.5 is a fine-tuning of Phi-3.5-mini-instruct, trained on a private high-quality dataset for structured information extraction. 
    It supports long documents and several languages (English, French, Spanish, German, Portuguese, and Italian). 
    To use the model, provide an input text and a JSON template describing the information you need to extract.</p>
    <ul>
        <li><strong>Model</strong>: <a href="https://huggingface.co/numind/NuExtract-v1.5">numind/NuExtract-v1.5</a></li>
    </ul>
    <i>⚠️ In this space we restrict the model inputs to a maximum length of 10k tokens, with anything over 4k being processed in a sliding window. For full model performance, self-host the model or contact us.</i>
    <br>
    <i>⚠️ The model is trained to assume a valid JSON template. Attempts to use invalid JSON could lead to unpredictable results.</i>
</body>
</html>
"""


def highlight_words(input_text, json_output):
    """Return *input_text* with the predicted values wrapped in coloured spans.

    Every ``(path, value)`` pair produced by ``extract_leaves(json_output)`` is
    searched for in the text. Matching is case-insensitive, and any run of
    whitespace in the text matches a space in the value. Each match is replaced
    by an HTML ``<span>`` with a background colour. Leaves with the same path
    share a colour, and colours are drawn in turn from a fixed palette.

    Values that are not non-blank strings (empty fields, numbers, None, ...)
    are skipped rather than breaking the whole highlighting pass.

    Args:
        input_text: Document text to annotate.
        json_output: Parsed prediction handed to ``extract_leaves``.

    Returns:
        The annotated text as an HTML fragment.
    """
    colors = cycle(["#90ee90", "#add8e6", "#ffb6c1", "#ffff99", "#ffa07a", "#20b2aa", "#87cefa", "#b0e0e6", "#dda0dd", "#ffdead"])
    color_map = {}
    highlighted_text = input_text

    leaves = extract_leaves(json_output)
    for path, value in leaves:
        # Assign the colour before any skipping, so each path keeps the same colour
        # no matter which other fields happen to be empty in this prediction.
        path_key = tuple(path)
        if path_key not in color_map:
            color_map[path_key] = next(colors)
        color = color_map[path_key]

        # Only non-blank strings can be located in the text. An empty value would
        # compile to a zero-width pattern that drops an empty <span> between
        # consecutive whitespace characters (e.g. blank lines). re.escape() raises
        # TypeError for values such as numbers or None.
        if not isinstance(value, str) or not value.strip():
            continue

        # Escape regex metacharacters, then let a space match any whitespace run.
        escaped_value = re.escape(value).replace(r'\ ', r'\s+')
        # Only match when preceded by whitespace and followed by whitespace or punctuation.
        pattern = rf"(?<=[ \n\t]){escaped_value}(?=[ \n\t\.\,\?\:\;])"
        replacement = f"<span style='background-color: {color};'>{unquote(value)}</span>"
        # Use a callable repl so the replacement is inserted verbatim. A string repl
        # has its backslashes processed by re.sub, so a value like C:\Users would
        # raise re.error ("bad escape").
        highlighted_text = re.sub(
            pattern,
            lambda _match, repl=replacement: repl,
            highlighted_text,
            flags=re.IGNORECASE,
        )

    # NOTE(review): input_text and the values are inserted without HTML escaping and the
    # result is rendered by gr.HTML, so any markup typed into the input is interpreted.
    return highlighted_text

def predict_chunk(text, template, current, model, tokenizer):
    """Run one NuExtract generation over a single chunk of text.

    Args:
        text: The chunk of document text to extract from.
        template: JSON template string describing the fields to extract.
        current: Prediction carried over from previous chunks. It is passed
            through ``clean_json_text`` before being embedded in the prompt.
        model: Causal LM used for generation.
        tokenizer: Tokenizer matching *model*.

    Returns:
        The text following the ``<|output|>`` marker in the decoded sequence,
        cleaned by ``clean_json_text``.
    """
    current = clean_json_text(current)

    # The prompt ends with "{" so the model continues directly inside a JSON object.
    input_llm = f"<|input|>\n### Template:\n{template}\n### Current:\n{current}\n### Text:\n{text}\n\n<|output|>" + "{"
    # Put the inputs on the model's own device instead of hard-coding "cuda". The
    # hard-coded device fails on CPU-only hosts and ignores the device_map placement.
    # NOTE(review): with truncation=True an over-long prompt is cut to MAX_INPUT_SIZE
    # tokens. If that removes the trailing "<|output|>" marker, the split below raises
    # IndexError. Confirm the tokenizer's truncation side.
    input_ids = tokenizer(input_llm, return_tensors="pt", truncation=True, max_length=MAX_INPUT_SIZE).to(model.device)
    generated = model.generate(**input_ids, max_new_tokens=MAX_NEW_TOKENS)
    # The decoded sequence still contains the prompt, so keep only what follows the
    # output marker. That part starts with the "{" seeded above.
    output = tokenizer.decode(generated[0], skip_special_tokens=True)

    return clean_json_text(output.split("<|output|>")[1])

def _try_parse_json(raw):
    """Return ``json.loads(raw)``, or None when *raw* is not valid JSON text."""
    try:
        return json.loads(raw)
    except (ValueError, TypeError):  # json.JSONDecodeError subclasses ValueError
        return None


def sliding_window_prediction(template, text, model, tokenizer, window_size=4000, overlap=128):
    """Extract *template* fields from *text* window by window, yielding after each.

    ``split_document`` cuts the text into overlapping chunks. Each chunk is
    passed to ``predict_chunk`` together with the previous chunk's prediction
    (the template itself for the first chunk), so later chunks build on the
    earlier results.

    Args:
        template: JSON template string describing the fields to extract.
        text: Full input document.
        model: Causal LM forwarded to ``predict_chunk``.
        tokenizer: Tokenizer forwarded to ``split_document`` and ``predict_chunk``.
        window_size: Chunk size forwarded to ``split_document``
            (presumably in tokens; confirm against ``split_document``).
        overlap: Overlap between consecutive chunks, forwarded to ``split_document``.

    Yields:
        ``(progress, prediction, highlighted_html)`` after every chunk.
        *prediction* is the chunk's output re-serialised with 4-space
        indentation when it parses to a non-empty JSON value, and the raw
        text otherwise.
    """
    chunks = split_document(text, window_size, overlap, tokenizer)

    # The template is the same for every chunk, so parse it only once.
    template_dict = _try_parse_json(template)

    prev = template
    for i, chunk in enumerate(chunks):
        print(f"Processing chunk {i}...")
        pred = predict_chunk(chunk, template, prev, model, tokenizer)

        # Let handle_broken_output repair the raw generation (it also receives
        # the previous prediction).
        pred = handle_broken_output(pred, prev)
        pred_dict = _try_parse_json(pred)

        # Highlighting is purely cosmetic. Fall back to the plain text when the
        # prediction is not JSON or highlighting fails for any reason.
        highlighted_pred = text
        if pred_dict is not None:
            try:
                highlighted_pred = highlight_words(text, pred_dict)
            except Exception:
                highlighted_pred = text

        # Sync empty fields with the template when both parsed to non-empty values.
        if template_dict and pred_dict:
            synced_pred = json.dumps(sync_empty_fields(pred_dict, template_dict), indent=4, ensure_ascii=False)
        elif pred_dict:
            synced_pred = json.dumps(pred_dict, indent=4, ensure_ascii=False)
        else:
            synced_pred = pred

        yield f"Processed chunk {i+1}/{len(chunks)}", synced_pred, highlighted_pred

        # The next chunk builds on this (unsynced) prediction.
        prev = pred


######

# Load the model and tokenizer.
model_name = "numind/NuExtract-v1.5"
# Optional Hub token. Falls back to False (no token) when HF_TOKEN is unset or empty.
auth_token = os.environ.get("HF_TOKEN") or False
# `token` replaces the deprecated `use_auth_token` argument; the same value is passed.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    token=auth_token,
)
tokenizer = AutoTokenizer.from_pretrained(model_name, token=auth_token)
# Inference only: put the model in evaluation mode.
model.eval()

def gradio_interface_function(template, text, is_example):
    """Gradio callback: stream extraction results for *text* using *template*.

    Yields ``(progress, model_output, highlighted_html)`` triples, one per
    processed chunk, matching the interface's three output components.
    Inputs longer than ``MAX_INPUT_SIZE`` tokens are rejected with a single
    error triple.

    If at least one chunk was processed, *is_example* is false, and the
    ``LOG_SERVER`` environment variable is set, the final prediction is
    passed to ``log_event``.
    """
    # Only the raw text is counted here; the template and prompt are not included.
    if len(tokenizer.tokenize(text)) > MAX_INPUT_SIZE:
        yield "", "Input text too long for space. Download model to use unrestricted.", ""
        return  # End the function since there was an error

    prediction_generator = sliding_window_prediction(template, text, model, tokenizer, window_size=MAX_WINDOW_SIZE)

    # Stays None if the generator yields nothing (e.g. no chunks). Previously that
    # case left full_pred unbound and crashed with UnboundLocalError at log_event.
    full_pred = None
    for progress, full_pred, html_content in prediction_generator:
        yield progress, full_pred, html_content

    if full_pred is not None and not is_example and os.environ.get("LOG_SERVER") is not None:
        log_event(text, template, full_pred)
        

# Set up the Gradio interface.
# The order of `inputs` must match the parameters of gradio_interface_function
# (template, text, is_example). The order of `outputs` must match the three values
# it yields (progress, model output, highlighted HTML).
iface = gr.Interface(
    description=markdown_description,
    fn=gradio_interface_function,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter Template here...", label="Template"),
        gr.Textbox(lines=2, placeholder="Enter input Text here...", label="Input Text"),
        # Hidden flag. When true, the run is not sent to log_event. Presumably the
        # rows in input_examples set it to True; TODO confirm.
        gr.Checkbox(label="Is Example?", visible=False),
    ],
    outputs=[
        gr.Textbox(label="Progress"),
        gr.Textbox(label="Model Output"),
        gr.HTML(label="Model Output with Highlighted Words"),
    ],
    examples=input_examples,
    # live=True  # Enable real-time updates
)

# Start the app.
iface.launch(debug=True, share=True)