Spaces:
Running
Running
Allow the user to specify a partially rewritten document.
Browse files
app.py
CHANGED
@@ -28,9 +28,10 @@ def get_model(model_name):
|
|
28 |
|
29 |
prompt = st.text_area("Prompt", "Rewrite this document to be more clear and concise.")
|
30 |
doc = st.text_area("Document", "This is a document that I would like to have rewritten to be more concise.")
|
|
|
31 |
|
32 |
|
33 |
-
def get_spans_local(prompt, doc):
|
34 |
import torch
|
35 |
|
36 |
tokenizer = get_tokenizer(model_name)
|
@@ -46,8 +47,10 @@ def get_spans_local(prompt, doc):
|
|
46 |
tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")[0]
|
47 |
assert len(tokenized_chat.shape) == 1
|
48 |
|
49 |
-
|
50 |
-
|
|
|
|
|
51 |
|
52 |
# Call the model
|
53 |
with torch.no_grad():
|
@@ -72,18 +75,22 @@ def get_spans_local(prompt, doc):
|
|
72 |
length_so_far += len(token)
|
73 |
return spans
|
74 |
|
75 |
-
def get_highlights_api(prompt, doc, timeout=60):
    """Fetch highlight spans for a document from the hosted highlights API.

    The request is a GET whose query parameters are ``prompt`` and ``doc``,
    e.g.
    https://tools.kenarnold.org/api/highlights?prompt=Rewrite%20this%20document&doc=This%20is%20a%20document

    Args:
        prompt: Rewrite instruction, sent as the ``prompt`` query parameter.
        doc: Document text, sent as the ``doc`` query parameter.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` waits indefinitely on a stalled server.

    Returns:
        The value stored under the ``'highlights'`` key of the JSON object
        returned by the API.

    Raises:
        requests.HTTPError: If the API answers with a 4xx/5xx status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    # Local import: requests is only loaded when this function is called.
    import requests

    response = requests.get(
        "https://tools.kenarnold.org/api/highlights",
        params=dict(prompt=prompt, doc=doc),
        timeout=timeout,
    )
    # Fail with a clear HTTP error rather than an opaque JSON/KeyError below.
    response.raise_for_status()
    # The response body is a JSON object; the spans live under 'highlights'.
    return response.json()['highlights']
|
82 |
|
83 |
if model_name == 'API':
|
84 |
-
spans = get_highlights_api(prompt, doc)
|
85 |
else:
|
86 |
-
spans = get_spans_local(prompt, doc)
|
|
|
|
|
|
|
|
|
87 |
|
88 |
highest_loss = max(span['token_loss'] for span in spans[1:])
|
89 |
for span in spans:
|
@@ -99,6 +106,5 @@ for span in spans:
|
|
99 |
)
|
100 |
html_out = f"<p style=\"background: white;\">{html_out}</p>"
|
101 |
|
102 |
-
st.subheader("Rewritten document")
|
103 |
st.write(html_out, unsafe_allow_html=True)
|
104 |
st.write(pd.DataFrame(spans)[['token', 'token_loss', 'most_likely_token', 'loss_ratio']])
|
|
|
28 |
|
29 |
prompt = st.text_area("Prompt", "Rewrite this document to be more clear and concise.")
|
30 |
doc = st.text_area("Document", "This is a document that I would like to have rewritten to be more concise.")
|
31 |
+
updated_doc = st.text_area("Updated Doc", help="Your edited document. Leave this blank to use your original document.")
|
32 |
|
33 |
|
34 |
+
def get_spans_local(prompt, doc, updated_doc):
|
35 |
import torch
|
36 |
|
37 |
tokenizer = get_tokenizer(model_name)
|
|
|
47 |
tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")[0]
|
48 |
assert len(tokenized_chat.shape) == 1
|
49 |
|
50 |
+
if len(updated_doc.strip()) == 0:
|
51 |
+
updated_doc = doc
|
52 |
+
updated_doc_ids = tokenizer(updated_doc, return_tensors='pt')['input_ids'][0]
|
53 |
+
joined_ids = torch.cat([tokenized_chat, updated_doc_ids[1:]])
|
54 |
|
55 |
# Call the model
|
56 |
with torch.no_grad():
|
|
|
75 |
length_so_far += len(token)
|
76 |
return spans
|
77 |
|
78 |
+
def get_highlights_api(prompt, doc, updated_doc, timeout=60):
    """Fetch highlight spans for a document from the hosted highlights API.

    The request is a GET whose query parameters are ``prompt``, ``doc`` and
    ``updated_doc``, e.g.
    https://tools.kenarnold.org/api/highlights?prompt=Rewrite%20this%20document&doc=This%20is%20a%20document&updated_doc=This%20is

    Args:
        prompt: Rewrite instruction, sent as the ``prompt`` query parameter.
        doc: Original document text, sent as the ``doc`` query parameter.
        updated_doc: The user's partially rewritten document, sent as the
            ``updated_doc`` query parameter. It is forwarded unchanged, even
            when blank.
            NOTE(review): presumably the server falls back to ``doc`` for a
            blank value, as the local code path does; confirm against the API.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout, ``requests`` waits indefinitely on a stalled server.

    Returns:
        The value stored under the ``'highlights'`` key of the JSON object
        returned by the API.

    Raises:
        requests.HTTPError: If the API answers with a 4xx/5xx status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    # Local import: requests is only loaded when this function is called.
    import requests

    response = requests.get(
        "https://tools.kenarnold.org/api/highlights",
        params=dict(prompt=prompt, doc=doc, updated_doc=updated_doc),
        timeout=timeout,
    )
    # Fail with a clear HTTP error rather than an opaque JSON/KeyError below.
    response.raise_for_status()
    # The response body is a JSON object; the spans live under 'highlights'.
    return response.json()['highlights']
|
85 |
|
86 |
if model_name == 'API':
|
87 |
+
spans = get_highlights_api(prompt, doc, updated_doc)
|
88 |
else:
|
89 |
+
spans = get_spans_local(prompt, doc, updated_doc)
|
90 |
+
|
91 |
+
if len(spans) < 2:
|
92 |
+
st.write("No spans found.")
|
93 |
+
st.stop()
|
94 |
|
95 |
highest_loss = max(span['token_loss'] for span in spans[1:])
|
96 |
for span in spans:
|
|
|
106 |
)
|
107 |
html_out = f"<p style=\"background: white;\">{html_out}</p>"
|
108 |
|
|
|
109 |
st.write(html_out, unsafe_allow_html=True)
|
110 |
st.write(pd.DataFrame(spans)[['token', 'token_loss', 'most_likely_token', 'loss_ratio']])
|