"""Streamlit app: browse classical Chinese books and translate text to modern Chinese.

Book files are listed from ``meta.csv``. Each file is read from a local
``data/`` directory when present, otherwise from the GitHub blob API of the
``garychowcmu/daizhigev20`` repository. The top container runs the
``raynardj/wenyanwen-ancient-translate-to-modern`` encoder-decoder model on
user-entered text.
"""
import base64
from pathlib import Path
from typing import Optional, Sequence

import pandas as pd
import requests
import streamlit as st
import torch
from requests.auth import HTTPBasicAuth

st.set_page_config(layout="wide")

# Inputs are truncated/padded to this many tokens, and the UI rejects text
# longer than this many characters.
MAX_INPUT_LENGTH = 168
# Files larger than this (bytes) are not fetched; the user gets a link instead.
MAX_FILE_SIZE = 3 * 1024 * 1024
# Seconds to wait on the GitHub API so a stalled request cannot hang the app.
REQUEST_TIMEOUT = 30


@st.cache(allow_output_mutation=True)
def load_model():
    """Load the tokenizer and translation model once and cache them.

    ``allow_output_mutation=True`` makes Streamlit skip its mutation check
    on the returned (large) objects on every rerun.
    """
    from transformers import (
        EncoderDecoderModel,
        AutoTokenizer,
    )
    PRETRAINED = "raynardj/wenyanwen-ancient-translate-to-modern"
    tokenizer = AutoTokenizer.from_pretrained(PRETRAINED)
    model = EncoderDecoderModel.from_pretrained(PRETRAINED)
    return tokenizer, model


tokenizer, model = load_model()


def inference(text):
    """Translate ``text`` with the cached model and return the decoded string.

    The input is truncated/padded to ``MAX_INPUT_LENGTH`` tokens. Beam search
    (3 beams) generates at most 256 tokens, and all spaces are removed from
    the decoded result.
    """
    tk_kwargs = dict(
        truncation=True,
        max_length=MAX_INPUT_LENGTH,
        padding="max_length",
        return_tensors="pt",
    )
    inputs = tokenizer([text], **tk_kwargs)
    with torch.no_grad():
        generated = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            num_beams=3,
            max_length=256,
            # NOTE(review): 101 is presumably the tokenizer's [CLS] id; confirm.
            bos_token_id=101,
            eos_token_id=tokenizer.sep_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
    return decoded[0].replace(" ", "")


@st.cache
def get_file_df():
    """Read and cache the book index ``meta.csv``."""
    file_df = pd.read_csv("meta.csv")
    return file_df


file_df = get_file_df()

st.sidebar.title("【随无涯】")
st.sidebar.markdown("""
* 朕亲自下厨的[🤗 翻译模型](https://github.com/raynardj/wenyanwen-ancient-translate-to-modern), [⭐️ 训练笔记](https://github.com/raynardj/yuan)
* 📚 书籍来自 [殆知阁](http://www.daizhige.org/),文本的[github api](https://github.com/garychowcmu/daizhigev20)
""")

# c2 (translation UI) is created first so it renders above c (book text).
c2 = st.container()
c = st.container()

USER_ID = st.secrets["USER_ID"]
SECRET = st.secrets["SECRET"]


@st.cache
def get_maps():
    """Return cached ``filepath -> obj_hash`` and ``filepath -> fsize`` dicts from meta.csv."""
    file_obj_hash_map = dict(file_df[["filepath", "obj_hash"]].values)
    file_size_map = dict(file_df[["filepath", "fsize"]].values)
    return file_obj_hash_map, file_size_map


file_obj_hash_map, file_size_map = get_maps()


def show_file_size(size: int) -> str:
    """Format a byte count as whole B / KB / MB (floored)."""
    if size < 1024:
        return f"{size} B"
    if size < 1024 * 1024:
        return f"{size // 1024} KB"
    # Integer floor division, consistent with the KB branch (previously a float).
    return f"{size // (1024 * 1024)} MB"


def fetch_file(path):
    """Return the text content of book file ``path``.

    Reads ``data/<path>`` when it exists locally. Otherwise it fetches the blob
    from the GitHub API using the object hash recorded in meta.csv.

    Raises:
        KeyError: ``path`` is not available locally and has no meta.csv entry.
        requests.HTTPError: the GitHub API returned a 4xx/5xx status.
    """
    local_path = Path("data") / path
    if local_path.exists():
        # Explicit UTF-8: the platform default encoding may fail on Chinese text.
        return local_path.read_text(encoding="utf-8")

    obj_hash = file_obj_hash_map[path]
    auth = HTTPBasicAuth(USER_ID, SECRET)
    url = f"https://api.github.com/repos/garychowcmu/daizhigev20/git/blobs/{obj_hash}"
    r = requests.get(url, auth=auth, timeout=REQUEST_TIMEOUT)
    # Raise on error statuses up front instead of falling through and returning None.
    r.raise_for_status()
    return base64.b64decode(r.json()["content"]).decode("utf-8")


def fetch_from_df(sub_paths: Sequence[str] = ()) -> Optional[list]:
    """List the distinct entries one level below ``sub_paths`` in the book tree.

    ``sub_paths[i]`` is matched against column ``col_i`` of meta.csv, and the
    unique non-null values of the next column are returned.
    Returns None when no row matches the given path.
    """
    sub_df = file_df
    for idx, step in enumerate(sub_paths):
        # Boolean mask instead of DataFrame.query: a name containing a quote
        # would break the generated query string.
        sub_df = sub_df[sub_df[f"col_{idx}"] == step]
        if sub_df.empty:
            return None
    # Drop NaN: rows whose path ends at a shallower depth have no value here.
    return list(sub_df[f"col_{len(sub_paths)}"].dropna().unique())


# Initialise per-session state on the first run. `translating` must exist
# before the `display_tree` guard below reads it; nothing else sets it first.
if "pathway" not in st.session_state:
    st.session_state.pathway = []
if "translating" not in st.session_state:
    st.session_state.translating = False

path_text = st.sidebar.text("/".join(st.session_state.pathway))


def reset_path():
    """Clear the navigation path and update the sidebar path display."""
    st.session_state.pathway = []
    path_text.text("/".join(st.session_state.pathway))


if st.sidebar.button("回到根目录"):
    reset_path()
    st.session_state.translating = False


def display_tree():
    """Render one level of the book selector, recursing until a ``.txt`` file is chosen.

    The selected entry is appended to ``st.session_state.pathway``. A non-file
    entry renders the next level recursively. A ``.txt`` file is shown in
    container ``c`` (or linked to daizhige.org when larger than
    ``MAX_FILE_SIZE``), and then the path is reset.
    """
    sublist = fetch_from_df(st.session_state.pathway)
    if not sublist:
        # Nothing to choose at this path: start over instead of crashing on
        # `None.endswith` below.
        reset_path()
        return None
    dropdown = st.sidebar.selectbox("【选书】", options=sublist)
    with st.spinner("加载中..."):
        st.session_state.pathway.append(dropdown)
        if not dropdown.endswith(".txt"):
            # A directory: show the path so far and render the next level.
            path_text.text("/".join(st.session_state.pathway))
            display_tree()
            return None

        filepath = "/".join(st.session_state.pathway)
        file_size = file_size_map[filepath]
        with st.spinner(f"loading file:{filepath},({show_file_size(file_size)})"):
            if file_size > MAX_FILE_SIZE:
                # Too large to fetch and render: link to the site's HTML page.
                urlpath = filepath[: -len(".txt")] + ".html"
                dzg = f"http://www.daizhige.org/{urlpath}"
                st.markdown(f"文件太大,[前往殆知阁页面]({dzg}), 或挑挑其他的书吧")
                reset_path()
                return None
            path_text.text(filepath)
            text = fetch_file(filepath)
            # NOTE(review): file text comes from a third-party repository and
            # is rendered with unsafe_allow_html=True, so any HTML in it is
            # rendered as-is. The string holds no height-limiting wrapper.
            c.markdown(f"\n{text}", unsafe_allow_html=True)
            reset_path()
    return None


if not st.session_state.translating:
    display_tree()


def translate_text():
    """Handle the 【翻译】 button: validate module-level ``cc`` and show its translation.

    Empty input and input longer than ``MAX_INPUT_LENGTH`` characters get a
    message instead of a translation.
    """
    st.session_state.translating = True
    if c2.button("【翻译】"):
        if not cc:
            c2.write("请输入文本")
        elif len(cc) > MAX_INPUT_LENGTH:
            c2.write("句子太长,最多168个字符")
        else:
            c2.markdown(f"""```{inference(cc)}```""")
    st.session_state.translating = False


cc = c2.text_area("【输入文本】", height=150)

translate_text()