raynardj's picture
Update app.py
d5ccfc3
raw
history blame
6.9 kB
import streamlit as st
import pandas as pd
from pathlib import Path
import requests
import base64
from requests.auth import HTTPBasicAuth
import torch
# First Streamlit call in this file: switch the page to the wide layout.
st.set_page_config(layout="wide")


@st.cache(allow_output_mutation=True)
def load_model():
    """Load the translation tokenizer and encoder-decoder model.

    The result is wrapped in ``st.cache`` so the checkpoint is not reloaded
    on every script rerun.

    Returns:
        A ``(tokenizer, model)`` tuple, both created with ``from_pretrained``
        from the same checkpoint name.
    """
    # Imported inside the function, as in the rest of this file's design.
    from transformers import AutoTokenizer, EncoderDecoderModel

    checkpoint = "raynardj/wenyanwen-ancient-translate-to-modern"
    # Tuple elements are evaluated left to right: tokenizer first, then model.
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        EncoderDecoderModel.from_pretrained(checkpoint),
    )


tokenizer, model = load_model()
def inference(text):
    """Translate a single passage with the module-level tokenizer and model.

    Args:
        text: The sentence to translate. It is tokenized with truncation and
            padding to a length of 168.

    Returns:
        The first decoded sequence, with special tokens skipped and every
        space character removed.
    """
    print(f"from: {text}")
    encoded = tokenizer(
        [text],
        truncation=True,
        max_length=168,
        padding="max_length",
        return_tensors='pt',
    )
    with torch.no_grad():
        generated_ids = model.generate(
            encoded.input_ids,
            attention_mask=encoded.attention_mask,
            num_beams=3,
            max_length=256,
            bos_token_id=101,
            eos_token_id=tokenizer.sep_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
        decoded = tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True)
        # Only one sentence was submitted, so take the first result.
        new = decoded[0].replace(" ", "")
    print(f"to: {new}")
    return new
@st.cache
def get_file_df():
    """Read ``meta.csv`` (relative to the working directory) into a DataFrame.

    Returns:
        The DataFrame produced by ``pd.read_csv("meta.csv")``.
    """
    return pd.read_csv("meta.csv")


file_df = get_file_df()
# Sidebar header: app title followed by links to the model page, the
# training repository, and the sources of the text corpus.
st.sidebar.title("【隨無涯】")
st.sidebar.markdown("""
* 朕自庖[🤗 模型](https://huggingface.co/raynardj/wenyanwen-ancient-translate-to-modern), [⭐️ 訓習處](https://github.com/raynardj/yuan)
* 📚 充棟汗牛,取自[殆知閣](http://www.daizhige.org/),[github api](https://github.com/garychowcmu/daizhigev20)
""")
# Container created first on the main page; the translation widgets defined
# further down in this file write into it.
c2 = st.container()
c2.write("The entirety of ancient Chinese literature, with a modern translator at your side.")
# Horizontal rule between the two containers.
st.markdown("""---""")
# Container that the file-display code fills with the selected text.
c = st.container()
# Credentials read from Streamlit secrets; fetch_file uses them for HTTP
# basic auth against the GitHub API.
USER_ID = st.secrets["USER_ID"]
SECRET = st.secrets["SECRET"]
@st.cache
def get_maps():
    """Build two lookup dicts keyed by the ``filepath`` column of ``file_df``.

    Returns:
        ``(file_obj_hash_map, file_size_map)``: the first maps each filepath
        to its ``obj_hash`` value, the second to its ``fsize`` value.
    """
    def column_by_path(column):
        # Each row of the two-column array is a (filepath, value) pair.
        return dict(file_df[["filepath", column]].values)

    return column_by_path("obj_hash"), column_by_path("fsize")


file_obj_hash_map, file_size_map = get_maps()
def show_file_size(size: int):
    """Format a byte count as a short human-readable string.

    Args:
        size: File size in bytes.

    Returns:
        ``"<n> B"`` below 1 KiB, ``"<n> KB"`` below 1 MiB, otherwise
        ``"<n> MB"``. KB and MB values are rounded down to whole units.
    """
    kib = 1024
    mib = kib * kib
    if size < kib:
        return f"{size} B"
    if size < mib:
        return f"{size // kib} KB"
    # Pure integer floor division: true division followed by floor division
    # would produce a float and render as e.g. "3.0 MB".
    return f"{size // mib} MB"
@st.cache(max_entries=100, allow_output_mutation=True)
def fetch_file(path):
    """Return the text of a corpus file, preferring a local copy.

    Args:
        path: File path relative to the corpus root; also the key into
            ``file_obj_hash_map``.

    Returns:
        The file content as a string. ``None`` only if the GitHub API answers
        with a non-200 status that ``raise_for_status`` does not treat as an
        error.

    Raises:
        KeyError: If ``path`` is not in ``file_obj_hash_map``.
        requests.HTTPError: If the GitHub API returns an error status.
    """
    # Reading from the local path first.
    local_path = Path("data") / path
    if local_path.exists():
        # Decode explicitly as UTF-8, matching the remote branch below,
        # instead of relying on the platform's default encoding.
        with open(local_path, "r", encoding="utf-8") as f:
            return f.read()

    # Otherwise read the blob from the GitHub API.
    obj_hash = file_obj_hash_map[path]
    url = f"https://api.github.com/repos/garychowcmu/daizhigev20/git/blobs/{obj_hash}"
    print(f"requesting {url}")
    # Without a timeout a stalled connection would block the app forever.
    r = requests.get(url, auth=HTTPBasicAuth(USER_ID, SECRET), timeout=30)
    if r.status_code != 200:
        r.raise_for_status()
        return None

    data = r.json()
    # The blob payload arrives base64-encoded in the "content" field.
    return base64.b64decode(data['content']).decode('utf-8')
@st.cache(allow_output_mutation=True, max_entries=100)
def fetch_from_df(sub_paths: str = ""):
    """List the distinct entries one level below a catalogue path.

    Args:
        sub_paths: Sequence of path components selected so far; component
            ``i`` is matched against column ``col_<i>`` of ``file_df``.
            Callers in this file pass a list of strings; the empty default
            selects the top level.

    Returns:
        The unique values of column ``col_<len(sub_paths)>`` among the
        matching rows, as a list, or ``None`` when no row matches.
    """
    sub_df = file_df
    for idx, step in enumerate(sub_paths):
        # Compare with a boolean mask instead of a string-built query so a
        # component containing a quote cannot break the expression.
        sub_df = sub_df[sub_df[f"col_{idx}"] == step]
    if sub_df.empty:
        return None
    return list(sub_df[f"col_{len(sub_paths)}"].unique())
def show_filepath(filepath: str):
    """Fetch a corpus file and render it in the main text container.

    Args:
        filepath: Path of the file, as accepted by ``fetch_file``.
    """
    import html

    text = fetch_file(filepath)
    # The block is rendered with unsafe_allow_html=True, so escape the file
    # content; otherwise any "<" or "&" in the text is interpreted as markup.
    escaped = html.escape(text)
    c.markdown(
        f"""<pre style='white-space:pre-wrap;max-height:300px;overflow-y:auto'>{escaped}</pre>""", unsafe_allow_html=True)
# Sidebar mode switch: browse the catalogue level by level, or search by title.
if st.sidebar.selectbox(
    label="何以尋跡 How to search",
    options=["以類尋書 category", "書名求書 search"],
) == "以類尋書 category":
    # The path chosen so far is kept in session state.
    if 'pathway' not in st.session_state:
        st.session_state.pathway = []

    # Sidebar placeholder that shows the current path.
    path_text = st.sidebar.text("/".join(st.session_state.pathway))

    def reset_path():
        """Clear the selected path and its sidebar display."""
        st.session_state.pathway = []
        # Pass a joined string like every other path_text update; handing
        # over the list itself would display its repr.
        path_text.text("/".join(st.session_state.pathway))

    if st.sidebar.button("還至初錄(back to root)"):
        reset_path()

    def display_tree():
        """Add one sidebar selectbox per catalogue level until a file is shown.

        Recurses while the selected entry is not a ``.txt`` file; once a file
        is selected it is displayed, unless it is larger than 3 MiB.
        """
        sublist = fetch_from_df(st.session_state.pathway)
        if not sublist:
            # fetch_from_df returns None when nothing matches the path;
            # start over instead of building a selectbox without options.
            reset_path()
            return None
        dropdown = st.sidebar.selectbox("【擇書 choose】", options=sublist)
        with st.spinner("書非借不能讀也..."):
            st.session_state.pathway.append(dropdown)
            if not dropdown.endswith('.txt'):
                # Still on a directory level: show the path and descend.
                path_text.text("/".join(st.session_state.pathway))
                display_tree()
                return None

            filepath = "/".join(st.session_state.pathway)
            file_size = file_size_map[filepath]
            with st.spinner(f"Load 載文:{filepath},({show_file_size(file_size)})"):
                # If the file is too large, do not load it; link out instead.
                if file_size > 3*1024*1024:
                    print(f"skip {filepath}")
                    urlpath = filepath.replace(".txt", ".html")
                    dzg = f"http://www.daizhige.org/{urlpath}"
                    st.markdown(f"File too big 其文碩而難載,不能為之,[往 殆知閣]({dzg}), 或擇他書")
                    reset_path()
                    return None
                path_text.text(filepath)
                print(f"read {filepath}")
                # Same fetch-and-render step as the search mode.
                show_filepath(filepath)
                # Reset only the stored path for the next rerun; calling
                # reset_path() here would overwrite the file path that was
                # just written to the sidebar.
                st.session_state.pathway = []
        return None

    display_tree()
else:
    def search_kw():
        """Show up to 15 files whose path contains the stored keyword."""
        # regex=False: the keyword is user input and must be matched
        # literally; as a pattern, input such as "(" raises an error.
        result = file_df[
            file_df.filepath.str.contains(st.session_state.kw, regex=False)
        ].reset_index(drop=True)
        if len(result) == 0:
            st.sidebar.write(f"尋之不得:{st.session_state.kw}")
        else:
            filepath = st.sidebar.selectbox("選一書名", options=list(result.head(15).filepath))
            show_filepath(filepath)

    def loading_with_search():
        """Read the keyword from the sidebar, store it, and run the search."""
        kw = st.sidebar.text_input("書名求書 Search", value="楞伽经")
        st.session_state.kw = kw
        search_kw()

    loading_with_search()
def translate_text():
    """Render the translate button and, when clicked, the translation result.

    Reads the module-level ``cc`` text. An empty input or an input longer
    than 168 characters produces a message instead of a translation.
    """
    if not c2.button("【曉文達義 Translate】"):
        return
    if not cc:
        c2.write("【入難曉之文字 Please input sentence for translating】")
    elif len(cc) > 168:
        c2.write(f"句甚長 不得過百又六十八字 Sentence too long, should be less than 168 characters")
    else:
        c2.markdown(f"""```{inference(cc)}```""")


# The input box is created before the call, so it sits above the button.
cc = c2.text_area("【入難曉之文字 Input sentence】", height=150)
translate_text()