Spaces:
Running
Running
File size: 3,972 Bytes
18f01d3 76b5bb9 18f01d3 aa3a290 18f01d3 fa263f0 18f01d3 6f71907 fa263f0 18f01d3 88c770b 18f01d3 88c770b fa263f0 aa3a290 88c770b aa3a290 88c770b aa3a290 88c770b aa3a290 88c770b aa3a290 88c770b 18f01d3 92076bd 18f01d3 a321d69 3ff8b95 58552f2 3ff8b95 92076bd 18f01d3 92076bd aa3a290 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
import streamlit as st
import pandas as pd
import html
# Models the user can pick from; choosing 'other' reveals a free-text box
# for an arbitrary Hugging Face model id. 'API' routes to the hosted service.
model_options = [
    'API',
    'google/gemma-1.1-2b-it',
    'google/gemma-1.1-7b-it',
]
choices = [*model_options, 'other']
model_name = st.selectbox("Select a model", choices)
if model_name == 'other':
    model_name = st.text_input("Enter model name", model_options[0])
@st.cache_resource
def get_tokenizer(model_name):
    """Load and cache the tokenizer for ``model_name``.

    Bug fix: the original chained ``.from_pretrained(model_name)`` twice —
    the second call is the ``from_pretrained`` classmethod reached through the
    freshly built instance, so the tokenizer was loaded twice. One load is
    sufficient and halves tokenizer startup cost.
    """
    from transformers import AutoTokenizer
    return AutoTokenizer.from_pretrained(model_name)
@st.cache_resource
def get_model(model_name):
    """Load and cache a causal LM in bfloat16, device-mapped automatically."""
    import torch
    from transformers import AutoModelForCausalLM

    loaded = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map='auto',
        torch_dtype=torch.bfloat16,
    )
    print(f"Loaded model, {loaded.num_parameters():,d} parameters.")
    return loaded
# Three free-text inputs driving the highlight computation:
# - prompt: the instruction given to the model,
# - doc: the original document,
# - updated_doc: an optional edited version (blank falls back to `doc` downstream).
prompt = st.text_area("Prompt", "Rewrite this document to be more clear and concise.")
doc = st.text_area("Document", "This is a document that I would like to have rewritten to be more concise.")
updated_doc = st.text_area("Updated Doc", help="Your edited document. Leave this blank to use your original document.")
def get_spans_local(prompt, doc, updated_doc):
    """Score each token of the (possibly edited) document under the local model.

    Builds a chat prompt from `prompt` + `doc`, appends the token ids of
    `updated_doc` (falling back to `doc` when the edit box is blank), runs a
    single forward pass, and returns one dict per document token with:
    character offsets (`start`/`end`), the decoded `token`, its negative
    log-likelihood (`token_loss`), and the model's `most_likely_token`.

    Reads the module-level `model_name` to select which tokenizer/model to load.
    """
    import torch
    tokenizer = get_tokenizer(model_name)
    model = get_model(model_name)
    messages = [
        {
            "role": "user",
            "content": f"{prompt}\n\n{doc}",
        },
    ]
    # 1-D id tensor of the chat-formatted prompt, generation prompt appended.
    tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")[0]
    assert len(tokenized_chat.shape) == 1
    if len(updated_doc.strip()) == 0:
        # Blank edit box means "score the original document".
        updated_doc = doc
    updated_doc_ids = tokenizer(updated_doc, return_tensors='pt')['input_ids'][0]
    # [1:] drops the first document id — presumably a BOS token the tokenizer
    # prepends; TODO(review): confirm for tokenizers that don't add BOS.
    joined_ids = torch.cat([tokenized_chat, updated_doc_ids[1:]])
    # Call the model
    with torch.no_grad():
        logits = model(joined_ids[None].to(model.device)).logits[0].cpu()
    spans = []
    length_so_far = 0  # running character offset into the decoded document text
    for idx in range(len(tokenized_chat), len(joined_ids)):
        # logits[idx - 1] is the distribution that predicts the token at idx.
        probs = logits[idx - 1].softmax(dim=-1)
        token_id = joined_ids[idx]
        token = tokenizer.decode(token_id)
        token_loss = -probs[token_id].log().item()
        most_likely_token_id = probs.argmax()
        print(idx, token, token_loss, tokenizer.decode(most_likely_token_id))
        spans.append(dict(
            start=length_so_far,
            end=length_so_far + len(token),
            token=token,
            token_loss=token_loss,
            most_likely_token=tokenizer.decode(most_likely_token_id)
        ))
        length_so_far += len(token)
    return spans
def get_highlights_api(prompt, doc, updated_doc):
    """Fetch per-token highlight spans from the hosted highlights service.

    `prompt`, `doc`, and `updated_doc` are sent as query parameters, e.g.
    https://tools.kenarnold.org/api/highlights?prompt=Rewrite%20this%20document&doc=This%20is%20a%20document
    Returns the 'highlights' list from the JSON response body.
    """
    import requests
    endpoint = "https://tools.kenarnold.org/api/highlights"
    query = {'prompt': prompt, 'doc': doc, 'updated_doc': updated_doc}
    payload = requests.get(endpoint, params=query).json()
    return payload['highlights']
# Route to the hosted API or the locally loaded model based on the selection.
compute_spans = get_highlights_api if model_name == 'API' else get_spans_local
spans = compute_spans(prompt, doc, updated_doc)

# Bail out early when there is nothing usable to highlight (the normalization
# below needs at least one span beyond the first).
if len(spans) < 2:
    st.write("No spans found.")
    st.stop()
# Normalize each token's loss to [0, 1] relative to the worst-scoring token.
# NOTE(review): the first span is excluded from the max — presumably its loss
# is dominated by the prompt/document boundary; confirm with the API contract.
highest_loss = max(span['token_loss'] for span in spans[1:])
if highest_loss == 0:
    # Every token was predicted with probability 1 (loss 0); fall back to 1.0
    # so the division below cannot raise ZeroDivisionError and crash the app.
    highest_loss = 1.0
for span in spans:
    span['loss_ratio'] = span['token_loss'] / highest_loss
# Render the spans as colored HTML: blue where the model would have chosen a
# different token, black where it agrees; the hover title shows the model's pick.
pieces = []
for span in spans:
    is_different = span['token'] != span['most_likely_token']
    pieces.append('<span style="color: {color}" title="{title}">{orig_token}</span>'.format(
        color="blue" if is_different else "black",
        title=html.escape(span["most_likely_token"]).replace('\n', ' '),
        orig_token=html.escape(span["token"]).replace('\n', '<br>')
    ))
html_out = f"<p style=\"background: white;\">{''.join(pieces)}</p>"
st.write(html_out, unsafe_allow_html=True)
st.write(pd.DataFrame(spans)[['token', 'token_loss', 'most_likely_token', 'loss_ratio']])
|