# raynardj — "Update app.py" (commit d5ccfc3)
import streamlit as st
import pandas as pd
from pathlib import Path
import requests
import base64
from requests.auth import HTTPBasicAuth
import torch
# Configure the page with Streamlit's "wide" layout; this is the first
# Streamlit call in the script, ahead of every element rendered below.
st.set_page_config(layout="wide")
@st.cache(allow_output_mutation=True)
def load_model():
    """Load the tokenizer and encoder-decoder model used for translation.

    ``transformers`` is imported inside the function rather than at module
    level, so the import happens only when this cached loader executes.

    Returns:
        A ``(tokenizer, model)`` tuple loaded from the
        ``raynardj/wenyanwen-ancient-translate-to-modern`` checkpoint.
    """
    import transformers

    checkpoint = "raynardj/wenyanwen-ancient-translate-to-modern"
    loaded_tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
    loaded_model = transformers.EncoderDecoderModel.from_pretrained(checkpoint)
    return loaded_tokenizer, loaded_model


tokenizer, model = load_model()
def inference(text):
    """Translate one classical-Chinese passage with the loaded model.

    The passage is tokenized (truncated and padded to 168 tokens), decoded
    with a 3-beam search of at most 256 tokens, and every space is removed
    from the decoded string.

    Args:
        text: The source passage.

    Returns:
        The translated text with all spaces stripped.
    """
    print(f"from: {text}")
    encoded = tokenizer(
        [text],
        truncation=True,
        max_length=168,
        padding="max_length",
        return_tensors='pt',
    )
    with torch.no_grad():
        generated_ids = model.generate(
            encoded.input_ids,
            attention_mask=encoded.attention_mask,
            num_beams=3,
            max_length=256,
            # NOTE(review): 101 is hard-coded; presumably the tokenizer's
            # [CLS] id for this checkpoint — confirm against its vocab.
            bos_token_id=101,
            eos_token_id=tokenizer.sep_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
        decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    translation = decoded[0].replace(" ", "")
    print(f"to: {translation}")
    return translation
@st.cache
def get_file_df():
    """Load the file-metadata table from ``meta.csv``.

    The rest of the script reads its ``filepath``, ``obj_hash``, ``fsize``
    and ``col_<n>`` (path segment) columns.
    """
    return pd.read_csv("meta.csv")


file_df = get_file_df()
# Sidebar header with links to the model, training repo and text sources.
st.sidebar.title("【隨無涯】")
st.sidebar.markdown("""
* 朕自庖[🤗 模型](https://huggingface.co/raynardj/wenyanwen-ancient-translate-to-modern), [⭐️ 訓習處](https://github.com/raynardj/yuan)
* 📚 充棟汗牛,取自[殆知閣](http://www.daizhige.org/),[github api](https://github.com/garychowcmu/daizhigev20)
""")
# Main-area container for the translator; the input box and the translate
# button are rendered into it at the end of the script.
c2 = st.container()
c2.write("The entirety of ancient Chinese literature, with a modern translator at your side.")
st.markdown("""---""")
# Main-area container into which the selected text is rendered.
c = st.container()
# Credentials for the HTTP basic auth used when fetching blobs from the
# GitHub API in fetch_file.
USER_ID = st.secrets["USER_ID"]
SECRET = st.secrets["SECRET"]
@st.cache
def get_maps():
    """Build lookup tables keyed by each catalogued file's relative path.

    Returns:
        ``(file_obj_hash_map, file_size_map)``: filepath -> git blob hash
        (the ``obj_hash`` column) and filepath -> size (the ``fsize`` column).
    """
    paths = file_df["filepath"]
    hash_by_path = dict(zip(paths, file_df["obj_hash"]))
    size_by_path = dict(zip(paths, file_df["fsize"]))
    return hash_by_path, size_by_path


file_obj_hash_map, file_size_map = get_maps()
def show_file_size(size: int) -> str:
    """Format a byte count as a short human-readable string.

    Values are floored to whole units, e.g. ``"512 B"``, ``"3 KB"``,
    ``"2 MB"`` (1 KB = 1024 B, 1 MB = 1024 KB).

    Args:
        size: Number of bytes.

    Returns:
        The size formatted with a B, KB or MB suffix.
    """
    if size < 1024:
        return f"{size} B"
    if size < 1024 * 1024:
        return f"{size // 1024} KB"
    # Pure floor division, consistent with the KB branch; the previous
    # ``size/1024//1024`` produced a float and rendered e.g. "3.0 MB".
    return f"{size // (1024 * 1024)} MB"
@st.cache(max_entries=100, allow_output_mutation=True)
def fetch_file(path):
    """Return the text of a catalogued file.

    A local copy under ``data/`` is used when present. Otherwise the file is
    fetched as a git blob from the ``garychowcmu/daizhigev20`` GitHub
    repository using HTTP basic auth, then base64-decoded as UTF-8.

    Args:
        path: Relative file path, as listed in the ``filepath`` column.

    Returns:
        The file content as a string.

    Raises:
        KeyError: If there is no local copy and ``path`` has no known blob hash.
        requests.HTTPError: If the GitHub API answers with an error status.
    """
    local_path = Path("data") / path
    if local_path.exists():
        # Explicit UTF-8 so the local read does not depend on the platform's
        # default encoding and matches the decoding of the remote branch.
        return local_path.read_text(encoding="utf-8")
    obj_hash = file_obj_hash_map[path]
    auth = HTTPBasicAuth(USER_ID, SECRET)
    url = f"https://api.github.com/repos/garychowcmu/daizhigev20/git/blobs/{obj_hash}"
    print(f"requesting {url}")
    # A timeout keeps a stalled API call from hanging the app indefinitely.
    r = requests.get(url, auth=auth, timeout=30)
    # Raise for 4xx/5xx up front; no status code can silently fall through to
    # an implicit ``return None`` any more.
    r.raise_for_status()
    data = r.json()
    return base64.b64decode(data["content"]).decode("utf-8")
@st.cache(allow_output_mutation=True, max_entries=100)
def fetch_from_df(sub_paths=""):
    """List the distinct next path segments below a catalogue path prefix.

    Each file path is stored split across columns ``col_0``, ``col_1``, ...;
    rows whose leading columns equal ``sub_paths`` are kept, and the unique
    values of the following column are returned.

    Args:
        sub_paths: Sequence of path segments (e.g. ``st.session_state.pathway``).
            An empty sequence means the catalogue root.

    Returns:
        A list of the unique ``col_<len(sub_paths)>`` values among matching
        rows, or ``None`` when no row matches the prefix.
    """
    sub_df = file_df
    for idx, step in enumerate(sub_paths):
        # Boolean-mask filtering instead of ``DataFrame.query`` on an
        # interpolated string, which broke on segments containing a quote.
        # Masking also returns a new frame, so file_df is never mutated.
        sub_df = sub_df[sub_df[f"col_{idx}"] == step]
    if sub_df.empty:
        return None
    return list(sub_df[f"col_{len(sub_paths)}"].unique())
def show_filepath(filepath: str):
    """Render a catalogued file's text into the main content container ``c``.

    The text is shown in a wrapping ``<pre>`` box capped at 300px height with
    vertical scrolling.

    Args:
        filepath: Relative file path, as listed in the ``filepath`` column.
    """
    import html

    text = fetch_file(filepath)
    # The file content is external data embedded in raw HTML, so escape it:
    # a literal "<" or "&" in the text would otherwise be parsed as markup.
    c.markdown(
        f"""<pre style='white-space:pre-wrap;max-height:300px;overflow-y:auto'>{html.escape(text)}</pre>""",
        unsafe_allow_html=True,
    )
if st.sidebar.selectbox(label="何以尋跡 How to search",options=["以類尋書 category","書名求書 search"])=="以類尋書 category":
    # Browse mode: walk the catalogue tree with one selectbox per path level.
    # st.session_state.pathway holds the segments chosen so far.
    if "pathway" not in st.session_state:
        st.session_state.pathway = []
    path_text = st.sidebar.text("/".join(st.session_state.pathway))

    def reset_path():
        """Clear the chosen path and the sidebar path display."""
        st.session_state.pathway = []
        # Show the joined path ("" at the root), like the other path_text
        # updates; previously the list repr "[]" was displayed.
        path_text.text("/".join(st.session_state.pathway))

    if st.sidebar.button("還至初錄(back to root)"):
        reset_path()

    def display_tree():
        """Render the selectbox for the next path level and descend.

        A directory choice is appended to the path and the next level is
        rendered recursively. A ``.txt`` choice is shown in the main area
        (unless larger than 3 MB), after which the path is reset.
        """
        # Only string segments can be navigated; missing (NaN) cells would
        # otherwise crash on ``.endswith``.
        sublist = [
            seg for seg in (fetch_from_df(st.session_state.pathway) or [])
            if isinstance(seg, str)
        ]
        if not sublist:
            # Nothing catalogued below this prefix: start over from the root.
            reset_path()
            return None
        dropdown = st.sidebar.selectbox("【擇書 choose】", options=sublist)
        with st.spinner("書非借不能讀也..."):
            st.session_state.pathway.append(dropdown)
            if not dropdown.endswith(".txt"):
                path_text.text("/".join(st.session_state.pathway))
                display_tree()
                return None
            filepath = "/".join(st.session_state.pathway)
            file_size = file_size_map[filepath]
            with st.spinner(f"Load 載文:{filepath},({show_file_size(file_size)})"):
                # Files over 3 MB are not loaded; link to the source site instead.
                if file_size > 3*1024*1024:
                    print(f"skip {filepath}")
                    urlpath = filepath.replace(".txt", ".html")
                    dzg = f"http://www.daizhige.org/{urlpath}"
                    st.markdown(f"File too big 其文碩而難載,不能為之,[往 殆知閣]({dzg}), 或擇他書")
                    reset_path()
                    return None
                path_text.text(filepath)
                print(f"read {filepath}")
                show_filepath(filepath)
                # Reset so the next rerun rebuilds the selectbox chain from
                # the root.
                reset_path()

    display_tree()
else:
    # Search mode: match a keyword against catalogue file paths.
    def search_kw():
        """Offer up to 15 paths containing st.session_state.kw and show the chosen file."""
        # The keyword is typed by the user, so match it as a literal
        # substring (regex=False): e.g. "(" would otherwise raise a regex
        # error. na=False keeps rows with a missing filepath out of the mask.
        matches = file_df.filepath.str.contains(st.session_state.kw, regex=False, na=False)
        result = file_df[matches].reset_index(drop=True)
        if len(result) == 0:
            st.sidebar.write(f"尋之不得:{st.session_state.kw}")
        else:
            filepath = st.sidebar.selectbox("選一書名", options=list(result.head(15).filepath))
            show_filepath(filepath)

    def loading_with_search():
        """Read the search keyword from the sidebar and run the search."""
        kw = st.sidebar.text_input("書名求書 Search", value="楞伽经")
        st.session_state.kw = kw
        search_kw()

    loading_with_search()
def translate_text():
    """Translate the text-area contents when the translate button is pressed.

    Reads the module-level ``cc`` (the input text area) and writes either a
    prompt for input, a length warning (over 168 characters) or the
    translation into container ``c2``.
    """
    if not c2.button("【曉文達義 Translate】"):
        return
    if not cc:
        c2.write("【入難曉之文字 Please input sentence for translating】")
    elif len(cc) > 168:
        c2.write("句甚長 不得過百又六十八字 Sentence too long, should be less than 168 characters")
    else:
        translated = inference(cc)
        c2.markdown("```" + translated + "```")


cc = c2.text_area("【入難曉之文字 Input sentence】", height=150)
translate_text()