import streamlit as st
import pandas as pd
from pathlib import Path
import requests
import base64
from requests.auth import HTTPBasicAuth
import torch

st.set_page_config(layout="wide")
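
# Translation model: raynardj/wenyanwen-ancient-translate-to-modern, loaded
# as a transformers EncoderDecoderModel that renders Classical Chinese
# (文言文) into modern Chinese. transformers is imported inside the function
# to keep the heavy import off the module top level.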
@st.cache_resource  # assumption: Streamlit >= 1.18; cache so the model loads once, not on every rerun
def load_model():
    from transformers import (
        EncoderDecoderModel,
        AutoTokenizer,
    )
    PRETRAINED = "raynardj/wenyanwen-ancient-translate-to-modern"
    tokenizer = AutoTokenizer.from_pretrained(PRETRAINED)
    model = EncoderDecoderModel.from_pretrained(PRETRAINED)
    return tokenizer, model


tokenizer, model = load_model()
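
# Translate one sentence. Input is truncated to 168 tokens (the UI below
# enforces the same limit in characters); generation uses 3-beam search.
# bos_token_id=101 is the [CLS] id of the underlying BERT-style tokenizer,
# and the final .replace strips the spaces batch_decode puts between
# Chinese characters.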
def inference(text):
    print(f"from: {text}")
    tk_kwargs = dict(
        truncation=True,
        max_length=168,
        padding="max_length",
        return_tensors="pt")
    inputs = tokenizer([text], **tk_kwargs)
    with torch.no_grad():
        new = tokenizer.batch_decode(
            model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                num_beams=3,
                max_length=256,
                bos_token_id=101,
                eos_token_id=tokenizer.sep_token_id,
                pad_token_id=tokenizer.pad_token_id,
            ), skip_special_tokens=True)[0].replace(" ", "")
    print(f"to: {new}")
    return new
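
# meta.csv indexes the corpus: one row per file with its filepath, git blob
# hash (obj_hash), size in bytes (fsize), and each path segment split into
# col_0, col_1, ... columns for the category browser below.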
def get_file_df():
    file_df = pd.read_csv("meta.csv")
    return file_df


file_df = get_file_df()
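
# Sidebar header: the app title and credits, linking the 🤗 model, the
# training code, and the Daizhige (殆知閣) source corpus on GitHub.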
st.sidebar.title("【隨無涯】")
st.sidebar.markdown("""
* 朕自庖[🤗 模型](https://huggingface.co/raynardj/wenyanwen-ancient-translate-to-modern), [⭐️ 訓習處](https://github.com/raynardj/yuan)
* 📚 充棟汗牛,取自[殆知閣](http://www.daizhige.org/),[github api](https://github.com/garychowcmu/daizhigev20)
""")

c2 = st.container()
c2.write("The entirety of ancient Chinese literature, with a modern translator at your side.")
st.markdown("""---""")

c = st.container()
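
# GitHub credentials from Streamlit secrets; authenticated requests get a
# higher GitHub API rate limit than anonymous ones.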
USER_ID = st.secrets["USER_ID"]
SECRET = st.secrets["SECRET"]


def get_maps():
    file_obj_hash_map = dict(file_df[["filepath", "obj_hash"]].values)
    file_size_map = dict(file_df[["filepath", "fsize"]].values)
    return file_obj_hash_map, file_size_map


file_obj_hash_map, file_size_map = get_maps()
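
# Render a byte count as a human-readable B / KB / MB string.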
def show_file_size(size: int):
    if size < 1024:
        return f"{size} B"
    elif size < 1024 * 1024:
        return f"{size // 1024} KB"
    else:
        return f"{size // (1024 * 1024)} MB"
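
# Fetch a text file: prefer a local copy under data/, otherwise pull the
# blob from the daizhigev20 repo via the GitHub git-blobs API, which
# returns the file content base64-encoded.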
def fetch_file(path):
    # read from the local data/ directory first
    if (Path("data") / path).exists():
        with open(Path("data") / path, "r") as f:
            return f.read()
    # fall back to the GitHub API
    obj_hash = file_obj_hash_map[path]
    auth = HTTPBasicAuth(USER_ID, SECRET)
    url = f"https://api.github.com/repos/garychowcmu/daizhigev20/git/blobs/{obj_hash}"
    print(f"requesting {url}")
    r = requests.get(url, auth=auth)
    if r.status_code == 200:
        data = r.json()
        content = base64.b64decode(data['content']).decode('utf-8')
        return content
    else:
        r.raise_for_status()
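
# Walk the category tree one level at a time: filter the rows whose
# col_0..col_{n-1} columns match the path chosen so far, then return the
# distinct entries at the next level (or None for a dead end).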
def fetch_from_df(sub_paths=()):
    sub_df = file_df.copy()
    for idx, step in enumerate(sub_paths):
        sub_df.query(f"col_{idx} == '{step}'", inplace=True)
        if len(sub_df) == 0:
            return None
    return list(sub_df[f"col_{len(sub_paths)}"].unique())


def show_filepath(filepath: str):
    text = fetch_file(filepath)
    c.markdown(
        f"""<pre style='white-space:pre-wrap;max-height:300px;overflow-y:auto'>{text}</pre>""",
        unsafe_allow_html=True)
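
# Two ways to find a book: drill down by category, or search by title.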
if st.sidebar.selectbox(
        label="何以尋跡 How to search",
        options=["以類尋書 category", "書名求書 search"]) == "以類尋書 category":
    if "pathway" not in st.session_state:
        st.session_state.pathway = []

    path_text = st.sidebar.text("/".join(st.session_state.pathway))

    def reset_path():
        st.session_state.pathway = []
        path_text.text("/".join(st.session_state.pathway))

    if st.sidebar.button("還至初錄(back to root)"):
        reset_path()
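
    # Recursively render one selectbox per tree level. Picking a .txt leaf
    # loads and displays the file, unless it exceeds 3 MB, in which case we
    # link to the original page on daizhige.org instead; picking a folder
    # recurses one level deeper.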
    def display_tree():
        sublist = fetch_from_df(st.session_state.pathway)
        dropdown = st.sidebar.selectbox("【擇書 choose】", options=sublist)
        with st.spinner("書非借不能讀也..."):
            st.session_state.pathway.append(dropdown)
            if dropdown.endswith('.txt'):
                filepath = "/".join(st.session_state.pathway)
                file_size = file_size_map[filepath]
                with st.spinner(f"Load 載文:{filepath},({show_file_size(file_size)})"):
                    # if the file is too large, do not load it
                    if file_size > 3 * 1024 * 1024:
                        print(f"skip {filepath}")
                        urlpath = filepath.replace(".txt", ".html")
                        dzg = f"http://www.daizhige.org/{urlpath}"
                        st.markdown(f"File too big 其文碩而難載,不能為之,[往 殆知閣]({dzg}), 或擇他書")
                        reset_path()
                        return None
                    path_text.text(filepath)
                    print(f"read {filepath}")
                    text = fetch_file(filepath)
                    # render inside a scrollable <pre> with a capped height
                    c.markdown(
                        f"""<pre style='white-space:pre-wrap;max-height:300px;overflow-y:auto'>{text}</pre>""",
                        unsafe_allow_html=True,
                    )
                    reset_path()
            else:
                path_text.text("/".join(st.session_state.pathway))
                display_tree()

    display_tree()
else:
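    # Search mode: substring-match the keyword against every filepath in
    # meta.csv and let the reader pick from the first 15 hits.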
    def search_kw():
        result = file_df[file_df.filepath.str.contains(
            st.session_state.kw)].reset_index(drop=True)
        if len(result) == 0:
            st.sidebar.write(f"尋之不得:{st.session_state.kw}")
        else:
            filepath = st.sidebar.selectbox(
                "選一書名", options=list(result.head(15).filepath))
            show_filepath(filepath)

    def loading_with_search():
        kw = st.sidebar.text_input("書名求書 Search", value="楞伽经")
        st.session_state.kw = kw
        search_kw()

    loading_with_search()
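
# Translation panel: the text area feeds inference() when the button is
# pressed; the 168-character cap mirrors the tokenizer max_length above.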
cc = c2.text_area("【入難曉之文字 Input sentence】", height=150)


def translate_text():
    if c2.button("【曉文達義 Translate】"):
        if cc:
            if len(cc) > 168:
                c2.write("句甚長 不得過百又六十八字 Sentence too long, should be less than 168 characters")
            else:
                c2.markdown(f"""```{inference(cc)}```""")
        else:
            c2.write("【入難曉之文字 Please input sentence for translating】")


translate_text()