Spaces:
Runtime error
Runtime error
File size: 4,926 Bytes
09234a0 b2d244e 09234a0 e9493c2 09234a0 6dbfedb 09234a0 f41461e 09234a0 e898e1c 09234a0 e898e1c 09234a0 823f13e 09234a0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import os
import streamlit as st
from langchain.llms import HuggingFaceHub
from models import return_sum_models
class LLM_Langchain:
    """Streamlit chat page that summarizes source code with a HuggingFace Hub model.

    Creating an instance renders the page header and the sidebar controls and
    stores the chosen settings on the instance:

    * ``API_KEY`` -- the HuggingFace token typed by the user (may be empty).
    * ``checkpoint`` -- the Hub repo id of the selected summarization model.
    * ``max_new_tokens``, ``num_beams``, ``top_k``, ``top_p`` -- slider values,
      also collected in ``model_kwargs`` and sent with every request.

    ``form_data`` then renders the chat history and answers a new prompt.
    """

    def __init__(self):
        self._render_intro()
        self.API_KEY = st.sidebar.text_input(
            'API key',
            type='password',
            help="Type in your HuggingFace API key to use this app")
        model_parent = st.sidebar.selectbox(
            label="Choose language",
            options=["python", "java", "javascript", "php", "ruby", "go", "cpp"],
            help="Choose languages",
        )
        # NOTE(review): the selectbox is not given ``index=None``, so it always
        # returns one of its options and the model box is never disabled in
        # practice; the guard is kept in case the widget config changes.
        model_name_disabled = model_parent is None
        # return_sum_models lives in models.py (not shown); presumably it maps a
        # language to the author's own checkpoint id -- verify there.
        model_name = return_sum_models(model_parent)
        self.checkpoint = st.sidebar.selectbox(
            label="Choose model (namnh113/... is my model)",
            options=self._candidate_checkpoints(model_parent, model_name),
            help="Model used to predict",
            disabled=model_name_disabled,
        )
        self.max_new_tokens = st.sidebar.slider(
            label="Token Length",
            min_value=32,
            max_value=248,
            step=4,
            value=128,
            help="Set the max tokens to get accurate results"
        )
        self.num_beams = st.sidebar.slider(
            label="num beams",
            min_value=1,
            max_value=10,
            step=1,
            value=2,
            help="Set num beam"
        )
        self.top_k = st.sidebar.slider(
            label="top k",
            min_value=1,
            max_value=50,
            step=1,
            value=30,
            help="Set the top_k"
        )
        self.top_p = st.sidebar.slider(
            label="top p",
            min_value=0.1,
            max_value=1.0,
            step=0.05,
            value=0.95,
            help="Set the top_p"
        )
        self.model_kwargs = {
            "max_new_tokens": self.max_new_tokens,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "num_beams": self.num_beams,
        }
        # The token is passed to HuggingFaceHub per request (generate_response)
        # rather than written to os.environ: the environment is shared by every
        # Streamlit session in the server process, so writing it there lets one
        # user's request run with another user's key.

    @staticmethod
    def _render_intro():
        """Render the page title, usage warnings, and reference links."""
        st.header('🦜 Code summarization')
        st.warning("Warning: input function needs cleaning and may take long to be processed at first time")
        st.warning("Note: you should not copy the whole function from IDE, the \"\\n\" character needs typing by hand")
        st.info("Reference: [CodeT5](https://arxiv.org/abs/2109.00859), [The Vault](https://arxiv.org/abs/2305.06156), [CodeXGLUE](https://arxiv.org/abs/2102.04664)")
        st.info("About me: namnh113")

    @staticmethod
    def _candidate_checkpoints(language, base_model):
        """Return the Hub repo ids offered for ``language``.

        ``base_model`` is always first. python and java also get its ``_v2``
        variant. Every language except cpp also gets the two Salesforce CodeT5
        summarization checkpoints.
        """
        options = [base_model]
        if language in ("python", "java"):
            options.append(base_model + "_v2")
        if language != "cpp":
            options += [
                "Salesforce/codet5-base-multi-sum",
                f"Salesforce/codet5-base-codexglue-sum-{language}",
            ]
        return options

    @staticmethod
    def _has_api_key(api_key):
        """Return True when ``api_key`` is a string with the ``hf_`` token prefix."""
        return isinstance(api_key, str) and api_key.startswith('hf_')

    @staticmethod
    def _format_summary(text):
        """Turn inline `` * `` separators in model output into Markdown list items."""
        return text.replace(' * ', '\n* ')

    def generate_response(self, input_text):
        """Send ``input_text`` to the selected Hub checkpoint and return its text output.

        The user's token is passed explicitly so the request never depends on
        process-wide environment state.
        """
        llm = HuggingFaceHub(
            repo_id=self.checkpoint,
            model_kwargs=self.model_kwargs,
            huggingfacehub_api_token=self.API_KEY,
        )
        return llm(input_text)

    def form_data(self):
        """Render the chat history, read one prompt, and answer it.

        The history lives in ``st.session_state.messages`` as a list of
        ``{"role", "content"}`` dicts. Typing ``clear`` deletes the history
        instead of querying the model. The chat input stays disabled until an
        ``hf_``-prefixed key is entered. Any exception raised while rendering or
        querying is shown on the page with ``st.error`` instead of crashing the
        script.
        """
        try:
            has_key = self._has_api_key(self.API_KEY)
            if not has_key:
                st.warning('Please enter your API key!', icon='⚠')
            if "messages" not in st.session_state:
                st.session_state.messages = []
            st.write(f"You are using {self.checkpoint} model")
            for message in st.session_state.messages:
                with st.chat_message(message.get('role')):
                    st.write(message.get("content"))
            # Without a usable token every request would fail, so do not accept
            # prompts until one is provided.
            text = st.chat_input(disabled=not has_key)
            if not text:
                return
            # Handle the command before touching the history so "clear" is
            # neither stored nor echoed.
            if text.strip().lower() == "clear":
                del st.session_state.messages
                return
            st.session_state.messages.append(
                {
                    "role": "user",
                    "content": text,
                }
            )
            with st.chat_message("user"):
                st.write(text)
            result = self._format_summary(self.generate_response(text))
            st.session_state.messages.append(
                {
                    "role": "assistant",
                    "content": result,
                }
            )
            with st.chat_message('assistant'):
                st.markdown(result)
        except Exception as e:
            # Top-level UI boundary: surface the failure in the page.
            st.error(str(e), icon="🚨")
# Streamlit re-executes this script top to bottom on every interaction:
# build the sidebar/config, then render the chat.
model = LLM_Langchain()
model.form_data()