kcarnold committed on
Commit
88c770b
·
1 Parent(s): 6f71907

Allow the user to specify a partially rewritten document.

Browse files
Files changed (1) hide show
  1. app.py +14 -8
app.py CHANGED
@@ -28,9 +28,10 @@ def get_model(model_name):
28
 
29
  prompt = st.text_area("Prompt", "Rewrite this document to be more clear and concise.")
30
  doc = st.text_area("Document", "This is a document that I would like to have rewritten to be more concise.")
 
31
 
32
 
33
- def get_spans_local(prompt, doc):
34
  import torch
35
 
36
  tokenizer = get_tokenizer(model_name)
@@ -46,8 +47,10 @@ def get_spans_local(prompt, doc):
46
  tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")[0]
47
  assert len(tokenized_chat.shape) == 1
48
 
49
- doc_ids = tokenizer(doc, return_tensors='pt')['input_ids'][0]
50
- joined_ids = torch.cat([tokenized_chat, doc_ids[1:]])
 
 
51
 
52
  # Call the model
53
  with torch.no_grad():
@@ -72,18 +75,22 @@ def get_spans_local(prompt, doc):
72
  length_so_far += len(token)
73
  return spans
74
 
75
- def get_highlights_api(prompt, doc):
76
  # Make a request to the API. prompt and doc are query parameters:
77
  # https://tools.kenarnold.org/api/highlights?prompt=Rewrite%20this%20document&doc=This%20is%20a%20document
78
  # The response is a JSON array
79
  import requests
80
- response = requests.get("https://tools.kenarnold.org/api/highlights", params=dict(prompt=prompt, doc=doc))
81
  return response.json()['highlights']
82
 
83
  if model_name == 'API':
84
- spans = get_highlights_api(prompt, doc)
85
  else:
86
- spans = get_spans_local(prompt, doc)
 
 
 
 
87
 
88
  highest_loss = max(span['token_loss'] for span in spans[1:])
89
  for span in spans:
@@ -99,6 +106,5 @@ for span in spans:
99
  )
100
  html_out = f"<p style=\"background: white;\">{html_out}</p>"
101
 
102
- st.subheader("Rewritten document")
103
  st.write(html_out, unsafe_allow_html=True)
104
  st.write(pd.DataFrame(spans)[['token', 'token_loss', 'most_likely_token', 'loss_ratio']])
 
28
 
29
  prompt = st.text_area("Prompt", "Rewrite this document to be more clear and concise.")
30
  doc = st.text_area("Document", "This is a document that I would like to have rewritten to be more concise.")
31
+ updated_doc = st.text_area("Updated Doc", help="Your edited document. Leave this blank to use your original document.")
32
 
33
 
34
+ def get_spans_local(prompt, doc, updated_doc):
35
  import torch
36
 
37
  tokenizer = get_tokenizer(model_name)
 
47
  tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")[0]
48
  assert len(tokenized_chat.shape) == 1
49
 
50
+ if len(updated_doc.strip()) == 0:
51
+ updated_doc = doc
52
+ updated_doc_ids = tokenizer(updated_doc, return_tensors='pt')['input_ids'][0]
53
+ joined_ids = torch.cat([tokenized_chat, updated_doc_ids[1:]])
54
 
55
  # Call the model
56
  with torch.no_grad():
 
75
  length_so_far += len(token)
76
  return spans
77
 
78
+ def get_highlights_api(prompt, doc, updated_doc):
79
  # Make a request to the API. prompt and doc are query parameters:
80
  # https://tools.kenarnold.org/api/highlights?prompt=Rewrite%20this%20document&doc=This%20is%20a%20document
81
  # The response is a JSON array
82
  import requests
83
+ response = requests.get("https://tools.kenarnold.org/api/highlights", params=dict(prompt=prompt, doc=doc, updated_doc=updated_doc))
84
  return response.json()['highlights']
85
 
86
  if model_name == 'API':
87
+ spans = get_highlights_api(prompt, doc, updated_doc)
88
  else:
89
+ spans = get_spans_local(prompt, doc, updated_doc)
90
+
91
+ if len(spans) < 2:
92
+ st.write("No spans found.")
93
+ st.stop()
94
 
95
  highest_loss = max(span['token_loss'] for span in spans[1:])
96
  for span in spans:
 
106
  )
107
  html_out = f"<p style=\"background: white;\">{html_out}</p>"
108
 
 
109
  st.write(html_out, unsafe_allow_html=True)
110
  st.write(pd.DataFrame(spans)[['token', 'token_loss', 'most_likely_token', 'loss_ratio']])