# Spaces:
# Runtime error
# Runtime error
import base64
import html
from pathlib import Path
from typing import Sequence

import pandas as pd
import requests
import streamlit as st
import torch
from requests.auth import HTTPBasicAuth
st.set_page_config(layout="wide")


# Streamlit re-executes this whole script on every widget interaction, so an
# uncached loader would rebuild the tokenizer and model on each rerun.
# st.cache_resource exists from Streamlit 1.18; older releases only provide
# st.cache, where allow_output_mutation=True skips hashing the returned model.
_cache_model = (
    st.cache_resource
    if hasattr(st, "cache_resource")
    else st.cache(allow_output_mutation=True)
)


@_cache_model
def load_model():
    """Load the classical-to-modern Chinese translation checkpoint.

    The result is cached by Streamlit and therefore shared by every rerun
    (and every session), so callers must treat it as read-only.

    Returns:
        (tokenizer, model): the ``AutoTokenizer`` and ``EncoderDecoderModel``
        for ``raynardj/wenyanwen-ancient-translate-to-modern``.
    """
    # Imported lazily so the heavy transformers import happens only here.
    from transformers import AutoTokenizer, EncoderDecoderModel

    PRETRAINED = "raynardj/wenyanwen-ancient-translate-to-modern"
    tokenizer = AutoTokenizer.from_pretrained(PRETRAINED)
    model = EncoderDecoderModel.from_pretrained(PRETRAINED)
    return tokenizer, model


tokenizer, model = load_model()
def inference(text):
    """Translate one passage with the module-level ``tokenizer``/``model``.

    The input is truncated/padded to 168 tokens. Beam search (3 beams, at
    most 256 output tokens) generates the translation.

    Args:
        text: Source passage to translate.

    Returns:
        The decoded translation with special tokens removed and every space
        stripped (presumably the tokenizer joins CJK tokens with spaces).
    """
    encoded = tokenizer(
        [text],
        truncation=True,
        max_length=168,
        padding="max_length",
        return_tensors="pt",
    )
    with torch.no_grad():
        generated = model.generate(
            encoded.input_ids,
            attention_mask=encoded.attention_mask,
            num_beams=3,
            max_length=256,
            # NOTE(review): 101 is presumably the [CLS] id of this BERT-style
            # vocabulary — confirm against tokenizer.cls_token_id.
            bos_token_id=101,
            eos_token_id=tokenizer.sep_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
    first, *_ = tokenizer.batch_decode(generated, skip_special_tokens=True)
    return first.replace(" ", "")
def get_file_df():
    """Read the book catalogue ``meta.csv`` from the working directory.

    Returns:
        pandas.DataFrame: the catalogue as parsed by ``pd.read_csv``.
    """
    return pd.read_csv("meta.csv")
file_df = get_file_df()

_sidebar = st.sidebar
_sidebar.title("【随无涯】")
# Credits: the translation model, its training notes, and the text source.
_sidebar.markdown("""
* 朕亲自下厨的[🤗 翻译模型](https://github.com/raynardj/wenyanwen-ancient-translate-to-modern), [⭐️ 训练笔记](https://github.com/raynardj/yuan)
* 📚 书籍来自 [殆知阁](http://www.daizhige.org/),文本的[github api](https://github.com/garychowcmu/daizhigev20)
""")

# Creation order fixes layout order: c2 (translator input/output) is placed
# above c (book text viewer).
c2, c = st.container(), st.container()

# Basic-auth credentials used by fetch_file() for the GitHub API fallback.
USER_ID, SECRET = st.secrets["USER_ID"], st.secrets["SECRET"]
def get_maps():
    """Build lookup tables keyed by each book's ``filepath``.

    Reads the module-level ``file_df``.

    Returns:
        (file_obj_hash_map, file_size_map): ``filepath -> obj_hash`` and
        ``filepath -> fsize``. When a path is duplicated, the last row wins.
    """
    paths = file_df["filepath"]
    hash_by_path = dict(zip(paths, file_df["obj_hash"]))
    size_by_path = dict(zip(paths, file_df["fsize"]))
    return hash_by_path, size_by_path
# Path -> blob hash (used by fetch_file) and path -> size (used by display_tree).
(file_obj_hash_map, file_size_map) = get_maps()
def show_file_size(size: int):
    """Format a byte count as a short human-readable string.

    Uses whole, floored units: bytes below 1024, KB below 1024*1024, MB
    otherwise.

    >>> show_file_size(512)
    '512 B'
    >>> show_file_size(2048)
    '2 KB'
    >>> show_file_size(3 * 1024 * 1024)
    '3 MB'
    """
    if size < 1024:
        return f"{size} B"
    if size < 1024 * 1024:
        return f"{size // 1024} KB"
    # Integer division, like the KB branch; the former `size/1024//1024`
    # produced a float and rendered e.g. "3.0 MB".
    return f"{size // (1024 * 1024)} MB"
def fetch_file(path, *, timeout=30):
    """Return the text of book ``path``, preferring a local copy.

    First looks for ``data/<path>`` on disk. Otherwise the file is
    downloaded from the GitHub blobs API of ``garychowcmu/daizhigev20``. The
    download uses the hash in ``file_obj_hash_map`` and basic auth with
    ``USER_ID``/``SECRET``.

    Args:
        path: Relative book path, as stored in the catalogue's ``filepath``.
        timeout: Seconds to wait for the GitHub API before giving up.

    Returns:
        The file contents decoded as UTF-8.

    Raises:
        KeyError: ``path`` is neither local nor in ``file_obj_hash_map``.
        requests.HTTPError: GitHub answered with a status other than 200.
        requests.Timeout: GitHub did not respond within ``timeout`` seconds.
    """
    local_path = Path("data") / path
    if local_path.exists():
        # Explicit UTF-8 (matching the remote branch); the platform default
        # encoding depends on the locale and may not decode Chinese text.
        return local_path.read_text(encoding="utf-8")

    obj_hash = file_obj_hash_map[path]
    url = f"https://api.github.com/repos/garychowcmu/daizhigev20/git/blobs/{obj_hash}"
    # A timeout keeps a stalled connection from blocking this run forever.
    r = requests.get(url, auth=HTTPBasicAuth(USER_ID, SECRET), timeout=timeout)
    if r.status_code != 200:
        r.raise_for_status()  # 4xx / 5xx
        # Any other non-200 status used to fall through and return None.
        raise requests.HTTPError(
            f"Unexpected HTTP status {r.status_code} for {url}", response=r
        )
    data = r.json()
    # The blob payload is base64-encoded in the "content" field.
    return base64.b64decode(data["content"]).decode("utf-8")
def fetch_from_df(sub_paths: Sequence[str] = ()):
    """List the entries one level below ``sub_paths`` in the book catalogue.

    ``file_df`` holds path components in columns ``col_0``, ``col_1``, ...
    Rows are kept when their first ``len(sub_paths)`` components equal
    ``sub_paths``. The distinct values of the next column are then returned.

    Args:
        sub_paths: Path components selected so far (e.g. the session
            ``pathway`` list); empty means the root.

    Returns:
        Unique values of ``col_{len(sub_paths)}`` in first-seen order, or
        ``None`` if no row matches ``sub_paths``.
    """
    sub_df = file_df
    for idx, step in enumerate(sub_paths):
        # Boolean mask rather than DataFrame.query: interpolating the step
        # into a query string broke on names containing a quote and let the
        # name be parsed as part of the expression.
        sub_df = sub_df[sub_df[f"col_{idx}"] == step]
    if sub_df.empty:
        return None
    return list(sub_df[f"col_{len(sub_paths)}"].unique())
# Session state: the path components chosen so far in the book browser.
if "pathway" not in st.session_state:
    st.session_state.pathway = []

# Sidebar line showing the current selection as "a/b/c"; later code updates
# it in place through this handle.
_current_path = "/".join(st.session_state.pathway)
path_text = st.sidebar.text(_current_path)
def reset_path():
    """Clear the selected book path and blank the sidebar path display."""
    st.session_state.pathway = []
    # Show the same "/"-joined form used elsewhere (empty for the root);
    # passing the list itself rendered a literal "[]" in the sidebar.
    path_text.text("")
# "Back to root" button: drop the current selection path.
_go_root = st.sidebar.button("回到根目录")
if _go_root:
    reset_path()

# Cleared at this point of every script run.
st.session_state["translating"] = False
def display_tree():
    """Render one level of the book-tree selector and descend into the choice.

    Lists the entries under ``st.session_state.pathway`` in a sidebar
    selectbox and appends the selection to the pathway. What happens next
    depends on the selection:

    - A ``.txt`` selection is fetched and shown in container ``c``. Files
      above 3 MB instead get a link to the daizhige.org page. Either way the
      pathway is then reset.
    - Any other selection is treated as a directory, and the function
      recurses into it.
    """
    sublist = fetch_from_df(st.session_state.pathway)
    dropdown = st.sidebar.selectbox("【选书】", options=sublist)
    with st.spinner("加载中..."):
        st.session_state.pathway.append(dropdown)
        if not dropdown.endswith(".txt"):
            # A directory: show the partial path and render the next level.
            path_text.text("/".join(st.session_state.pathway))
            display_tree()
            return None

        filepath = "/".join(st.session_state.pathway)
        file_size = file_size_map[filepath]
        with st.spinner(f"loading file:{filepath},({show_file_size(file_size)})"):
            # Files above 3 MB are not loaded; link to the website instead.
            if file_size > 3 * 1024 * 1024:
                # Swap only the trailing extension; str.replace would also
                # rewrite any ".txt" appearing earlier in the path.
                urlpath = filepath[: -len(".txt")] + ".html"
                dzg = f"http://www.daizhige.org/{urlpath}"
                st.markdown(f"文件太大,[前往殆知阁页面]({dzg}), 或挑挑其他的书吧")
                reset_path()
                return None
            path_text.text(filepath)
            try:
                text = fetch_file(filepath)
                # The text is embedded in raw HTML (unsafe_allow_html), so
                # escape it: "<" or "&" in a book would otherwise be parsed
                # as markup.
                c.markdown(
                    f"""<pre style='max-height:300px;overflow-y:auto'>{html.escape(text)}</pre>""",
                    unsafe_allow_html=True,
                )
            finally:
                # Reset even if fetching/rendering raises, so the next rerun
                # does not start from this stale file-level pathway.
                reset_path()
# Render the book browser unless the "translating" flag is set.
if not st.session_state.translating:
    display_tree()
def translate_text(text=None):
    """Render the translate button and, when clicked, translate the input.

    Output goes to container ``c2``, with one of three outcomes:

    - the translation, in a code fence;
    - a "too long" message for inputs over the limit;
    - a prompt when the input is empty.

    ``st.session_state.translating`` is True while this runs.

    Args:
        text: Passage to translate. Defaults to the module-level ``cc``
            text-area value, read when the function runs.
    """
    if text is None:
        text = cc
    # Mirrors max_length=168 passed to the tokenizer in inference().
    # NOTE(review): that limit counts tokens (incl. special tokens), not
    # characters — confirm whether 168 characters can still be truncated.
    max_chars = 168
    st.session_state.translating = True
    try:
        if c2.button("【翻译】"):
            if not text:
                c2.write("请输入文本")
            elif len(text) > max_chars:
                c2.write(f"句子太长,最多{max_chars}个字符")
            else:
                c2.markdown(f"""```{inference(text)}```""")
    finally:
        # Clear the flag even if inference raises.
        st.session_state.translating = False
# Input box for the passage to translate, placed in the top container.
cc = c2.text_area(label="【输入文本】", height=150)
translate_text()