Spaces:
Sleeping
Sleeping
File size: 5,061 Bytes
455b92e e002a1d ee0adb2 76d7b2d e002a1d 768da65 22a6156 455b92e 5f73f54 768da65 3596135 455b92e 5f73f54 455b92e 5f73f54 e002a1d 7fb8736 e002a1d 5f73f54 d3e8b15 768da65 e002a1d 0e9a542 e420bbe 0e9a542 455b92e 7aa0c88 455b92e c03594b 455b92e c03594b 455b92e c03594b 455b92e aeaad97 455b92e c03594b 455b92e 5f73f54 455b92e d3e8b15 455b92e bafde03 455b92e 0e9a542 455b92e 99434df 455b92e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import os
import streamlit as st
from langchain.llms import HuggingFaceHub
from models import return_sum_models
class LLM_Langchain:
    """Streamlit page that summarizes a code snippet with a HuggingFace Hub model.

    Instantiating the class renders the page header/notes and the sidebar
    controls (API key, language, checkpoint, generation parameters) and keeps
    the selected values on the instance. ``form_data`` then renders the chat
    history and answers new input via ``generate_response``.
    """

    def __init__(self):
        st.header('🦜 Code summarization')
        st.warning("Warning: input function needs cleaning, and may take long to be processed at first time")
        st.warning("Note: you should not copy the whole function from IDE, the \"\\n\" needs typing by hand")
        st.info("Reference: [CodeT5](https://arxiv.org/abs/2109.00859), [The Vault](https://arxiv.org/abs/2305.06156), [CodeXGLUE](https://arxiv.org/abs/2102.04664)")
        self.api_key_area = st.sidebar.text_input(
            'API key (not necessary for now)',
            type='password',
            help="Type in your HuggingFace API key to use this app")
        # A key typed in the sidebar wins; otherwise fall back to the API_KEY
        # env var. A missing env var yields "" instead of a KeyError that
        # would abort the whole page before anything is rendered.
        self.API_KEY = self._resolve_api_key(self.api_key_area)
        self.model_parent = st.sidebar.selectbox(
            label="Choose language",
            options=["python", "java", "javascript", "php", "ruby", "go", "cpp"],
            help="Choose languages",
        )
        # The checkpoint picker is disabled while no language is selected.
        model_select_disabled = self.model_parent is None
        list_model = self._candidate_models(
            self.model_parent, return_sum_models(self.model_parent))
        self.checkpoint = st.sidebar.selectbox(
            label="Choose model (nam194/... is my model)",
            options=list_model,
            help="Model used to predict",
            disabled=model_select_disabled,
        )
        self.max_new_tokens = st.sidebar.slider(
            label="Token Length",
            min_value=32,
            max_value=1024,
            step=32,
            value=64,
            help="Set the max tokens to get accurate results",
        )
        self.num_beams = st.sidebar.slider(
            label="num beams",
            min_value=1,
            max_value=10,
            step=1,
            value=4,
            help="Set num beam",
        )
        self.top_k = st.sidebar.slider(
            label="top k",
            min_value=1,
            max_value=50,
            step=1,
            value=30,
            help="Set the top_k",
        )
        self.top_p = st.sidebar.slider(
            label="top p",
            min_value=0.1,
            max_value=1.0,
            step=0.05,
            value=0.95,
            help="Set the top_p",
        )
        # Generation settings forwarded unchanged to HuggingFaceHub.
        self.model_kwargs = {
            "max_new_tokens": self.max_new_tokens,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "num_beams": self.num_beams,
        }
        # Only export a non-empty token, so an empty value does not overwrite
        # a HUGGINGFACEHUB_API_TOKEN that is already configured.
        if self.API_KEY:
            os.environ['HUGGINGFACEHUB_API_TOKEN'] = self.API_KEY

    @staticmethod
    def _resolve_api_key(user_key, environ=None):
        """Return the HuggingFace API key to use.

        A non-blank ``user_key`` (the sidebar input) takes precedence.
        Otherwise the ``API_KEY`` entry of ``environ`` is used; ``environ``
        defaults to ``os.environ``. Returns ``""`` when neither is set.
        """
        if environ is None:
            environ = os.environ
        user_key = (user_key or "").strip()
        return user_key or environ.get("API_KEY", "")

    @staticmethod
    def _candidate_models(language, base_model):
        """Return the checkpoint names offered for ``language``.

        ``base_model`` (the name returned by ``return_sum_models``) comes
        first. python/java also get its ``_v2`` variant. Every language
        except cpp also gets the two Salesforce CodeT5 summarization
        checkpoints.
        """
        models = [base_model]
        if language in ("python", "java"):
            models.append(base_model + "_v2")
        if language != "cpp":
            models.extend([
                "Salesforce/codet5-base-multi-sum",
                f"Salesforce/codet5-base-codexglue-sum-{language}",
            ])
        return models

    @staticmethod
    def _build_prompt(language, code):
        """Prefix ``code`` with the ``"Summarize <Language>: "`` task prompt."""
        return "Summarize " + language.capitalize() + ": " + code

    @staticmethod
    def _is_clear_command(text):
        """Return True when ``text`` is the chat reset command ``clear``.

        Matching ignores case and surrounding whitespace.
        """
        return text.strip().lower() == "clear"

    def generate_response(self, input_text):
        """Return the model's summary of ``input_text``.

        The code is prefixed with a ``"Summarize <Language>: "`` prompt for
        the selected language. It is then sent to ``self.checkpoint`` via
        ``HuggingFaceHub`` with the sidebar generation settings.
        """
        llm = HuggingFaceHub(
            repo_id=self.checkpoint,
            model_kwargs=self.model_kwargs,
        )
        return llm(self._build_prompt(self.model_parent, input_text))

    def form_data(self):
        """Render the chat history and summarize newly submitted code.

        History is kept in ``st.session_state.messages`` as
        ``{"role", "content"}`` dicts. Submitting ``clear`` empties the
        history instead of querying the model. Any exception raised while
        rendering or querying is displayed with ``st.error``.
        """
        try:
            if not self.API_KEY.startswith('hf_'):
                st.warning('Please enter your API key!', icon='⚠')
            if "messages" not in st.session_state:
                st.session_state.messages = []
            st.write(f"You are using {self.checkpoint} model")
            for message in st.session_state.messages:
                with st.chat_message(message.get('role')):
                    st.write(message.get("content"))
            text = st.chat_input(disabled=False)
            if not text:
                return
            # Handle the reset command before recording anything, so "clear"
            # itself is never stored in or echoed to the history.
            if self._is_clear_command(text):
                st.session_state.messages = []
                return
            st.session_state.messages.append({"role": "user", "content": text})
            with st.chat_message("user"):
                st.write(text)
            result = self.generate_response(text)
            # Start each " * " bullet on its own line so markdown renders a list.
            result = result.replace(' * ', '\n* ')
            st.session_state.messages.append({"role": "assistant", "content": result})
            with st.chat_message('assistant'):
                st.markdown(result)
        except Exception as e:
            # UI boundary: show the failure in the page instead of a traceback.
            st.error(str(e), icon="🚨")
# Streamlit executes this script as __main__ on every rerun. The guard keeps
# the page from being built when the module is merely imported (e.g. in tests).
if __name__ == "__main__":
    model = LLM_Langchain()
    model.form_data()
|