Spaces:
Sleeping
Sleeping
import os | |
import streamlit as st | |
from langchain.llms import HuggingFaceHub | |
from models import return_sum_models | |
class LLM_Langchain():
    """Streamlit page that summarizes code through a Hugging Face Hub model.

    Constructing an instance draws the page header and the sidebar controls
    (API key, language, model, generation settings) and stores their current
    values on the instance; ``form_data`` then renders the chat interface.
    """

    def __init__(self):
        """Draw the header and sidebar widgets and capture their values."""
        st.warning("Warning: Input function need to be clean and may take long to process")
        st.header('🦜 Code summarization with CodeT5-small + CodeXGLUE dataset')
        st.warning("Enter your huggingface API key first !")
        self.API_KEY = st.sidebar.text_input(
            'API Key',
            type='password',
            help="Type in your HuggingFace API key to use this app")
        languages = ["php", "java", "javascript", "python", "ruby"]
        model_parent = st.sidebar.selectbox(
            label="Choose Language",
            options=languages,
            help="Choose languages",
        )
        # Without a language there is no model list to offer, so the model
        # picker is disabled and gets no options instead of calling
        # return_sum_models(None).
        model_name_visibility = model_parent is None
        options = [] if model_parent is None else return_sum_models(model_parent)
        self.model_name = st.sidebar.selectbox(
            label="Models",
            options=options,
            help="Chosen Model to predict",
            disabled=model_name_visibility,
        )
        self.max_new_tokens = st.sidebar.slider(
            label="Token Length",
            min_value=32,
            max_value=1024,
            step=32,
            value=120,
            help="Set the max tokens to get accurate results",
        )
        self.top_k = st.sidebar.slider(
            label="top_k",
            min_value=1,
            max_value=50,
            step=1,
            value=30,
            help="Set the top_k",
        )
        # st.slider requires all numeric arguments to share one type; since
        # min_value/step/value are floats, max_value must be the float 1.0
        # (an int here raises StreamlitAPIException).
        self.top_p = st.sidebar.slider(
            label="top_p",
            min_value=0.1,
            max_value=1.0,
            step=0.05,
            value=0.95,
            help="Set the top_p",
        )
        # Generation parameters forwarded to the Hub model on every request.
        self.model_kwargs = {
            "max_new_tokens": self.max_new_tokens,
            "top_k": self.top_k,
            "top_p": self.top_p,
        }
        # The token is passed to HuggingFaceHub directly in generate_response
        # instead of being written to os.environ: the environment is shared by
        # every Streamlit session in this process, so one user's key could
        # otherwise be used by, or overwritten by, another user's session.

    def generate_response(self, input_text):
        """Send ``input_text`` to the selected Hub model and return its output.

        Args:
            input_text: The prompt (code to summarize) entered by the user.

        Returns:
            Whatever the LangChain ``HuggingFaceHub`` LLM returns for the prompt.
        """
        llm = HuggingFaceHub(
            repo_id=self.model_name,
            huggingfacehub_api_token=self.API_KEY,
            model_kwargs=self.model_kwargs,
        )
        return llm(input_text)

    def form_data(self):
        """Render the chat history and answer a newly submitted message.

        Entering ``clear`` (any letter case) deletes the stored history instead
        of querying the model. Any exception raised while rendering or querying
        is shown on the page via ``st.error`` rather than propagated.
        """
        try:
            # Keys are expected to look like Hugging Face tokens ("hf_...").
            key_looks_valid = self.API_KEY.startswith('hf_')
            if not key_looks_valid:
                st.warning('Please enter your API key!', icon='⚠')
            if "messages" not in st.session_state:
                st.session_state.messages = []
            st.write(f"You are using {self.model_name} model")
            for message in st.session_state.messages:
                with st.chat_message(message.get('role')):
                    st.write(message.get("content"))
            # Keep the chat box disabled until a plausible key has been entered,
            # so no request is sent to the Hub without credentials.
            text = st.chat_input(disabled=not key_looks_valid)
            if not text:
                return
            # Handle the "clear" command before recording it as a message.
            if text.lower() == "clear":
                del st.session_state.messages
                return
            st.session_state.messages.append({"role": "user", "content": text})
            with st.chat_message("user"):
                st.write(text)
            result = self.generate_response(text)
            st.session_state.messages.append({"role": "assistant", "content": result})
            with st.chat_message('assistant'):
                st.markdown(result)
        except Exception as e:
            # Top-level UI boundary: surface the failure to the user.
            st.error(e, icon="🚨")
# Streamlit executes the app script as the ``__main__`` module, so this guard
# runs the page under ``streamlit run`` while keeping the module importable
# without drawing any UI.
if __name__ == "__main__":
    model = LLM_Langchain()
    model.form_data()