import os

import streamlit as st
from langchain.llms import HuggingFaceHub

from models import return_sum_models


class LLM_Langchain:
    """Streamlit chat page that summarizes source code via the HuggingFace Hub.

    Constructing an instance renders the page header, notices and the sidebar
    controls (API key, language, checkpoint, generation settings).
    ``form_data`` then renders the chat history and handles one new message.
    """

    def __init__(self):
        st.header('🦜 Code summarization')
        st.warning("Warning: input function needs cleaning, and may take long to be processed at first time")
        st.warning("Note: you should not copy the whole function from IDE, the \"\\n\" needs typing by hand")
        st.info("Reference: [CodeT5](https://arxiv.org/abs/2109.00859), [The Vault](https://arxiv.org/abs/2305.06156), [CodeXGLUE](https://arxiv.org/abs/2102.04664)")

        self.api_key_area = st.sidebar.text_input(
            'API key (not necessary for now)',
            type='password',
            help="Type in your HuggingFace API key to use this app")
        # Server-side key. Use .get() so a missing variable does not crash the
        # page before the user even gets a chance to supply a key.
        server_key = os.environ.get("API_KEY", "")
        # Prefer the key typed in the sidebar; otherwise fall back to the server key.
        self.API_KEY = (self.api_key_area or "").strip() or server_key

        self.model_parent = st.sidebar.selectbox(
            label="Choose language",
            options=["python", "java", "javascript", "php", "ruby", "go", "cpp"],
            help="Choose languages",
        )
        # The checkpoint selector is disabled until a language is chosen.
        model_select_disabled = self.model_parent is None

        model_name = return_sum_models(self.model_parent)
        list_model = [model_name]
        # Only these languages have a second ("_v2") in-house checkpoint.
        if self.model_parent in ["python", "java"]:
            list_model += [model_name + "_v2"]
        # The Salesforce CodeT5 summarization checkpoints do not cover cpp.
        if self.model_parent != "cpp":
            list_model += [
                "Salesforce/codet5-base-multi-sum",
                f"Salesforce/codet5-base-codexglue-sum-{self.model_parent}",
            ]

        self.checkpoint = st.sidebar.selectbox(
            label="Choose model (nam194/... is my model)",
            options=list_model,
            help="Model used to predict",
            disabled=model_select_disabled,
        )
        self.max_new_tokens = st.sidebar.slider(
            label="Token Length",
            min_value=32,
            max_value=1024,
            step=32,
            value=128,
            help="Set the max tokens to get accurate results",
        )
        self.num_beams = st.sidebar.slider(
            label="num beams",
            min_value=1,
            max_value=10,
            step=1,
            value=2,
            help="Set num beam",
        )
        self.top_k = st.sidebar.slider(
            label="top k",
            min_value=1,
            max_value=50,
            step=1,
            value=30,
            help="Set the top_k",
        )
        self.top_p = st.sidebar.slider(
            label="top p",
            min_value=0.1,
            max_value=1.0,
            step=0.05,
            value=0.95,
            help="Set the top_p",
        )
        # Generation settings forwarded verbatim to the Hub inference call.
        self.model_kwargs = {
            "max_new_tokens": self.max_new_tokens,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "num_beams": self.num_beams,
        }

        # Only the server-side key is put into the process environment. A key
        # typed by a user is passed per request in generate_response, so it is
        # never shared with other sessions through os.environ.
        if server_key:
            os.environ['HUGGINGFACEHUB_API_TOKEN'] = server_key

    def generate_response(self, input_text):
        """Return the model's summary of *input_text*.

        The text is prefixed with ``"Summarize <Language>: "`` before being
        sent to the selected checkpoint on the HuggingFace Hub.
        """
        input_text = "Summarize " + self.model_parent.capitalize() + ": " + input_text
        llm = HuggingFaceHub(
            repo_id=self.checkpoint,
            model_kwargs=self.model_kwargs,
            # None lets LangChain fall back to HUGGINGFACEHUB_API_TOKEN.
            huggingfacehub_api_token=self.API_KEY or None,
        )
        return llm(input_text)

    def form_data(self):
        """Render the chat history and process one new message, if any.

        Typing ``clear`` wipes the stored history. Any exception raised while
        rendering or querying the model is shown in an error box instead of
        crashing the page.
        """
        try:
            if not self.API_KEY.startswith('hf_'):
                st.warning('Please enter your API key!', icon='⚠')

            if "messages" not in st.session_state:
                st.session_state.messages = []

            st.write(f"You are using {self.checkpoint} model")

            for message in st.session_state.messages:
                with st.chat_message(message.get('role')):
                    st.write(message.get("content"))

            text = st.chat_input(disabled=False)
            if not text:
                return

            # Handle the command before recording it, so "clear" itself is
            # neither stored in the history nor echoed back.
            if text.strip().lower() == "clear":
                st.session_state.messages = []
                return

            st.session_state.messages.append({"role": "user", "content": text})
            with st.chat_message("user"):
                st.write(text)

            result = self.generate_response(text)
            # Put each " * " bullet on its own line so it renders as a markdown list.
            result = result.replace(' * ', '\n* ')
            st.session_state.messages.append({"role": "assistant", "content": result})
            with st.chat_message('assistant'):
                st.markdown(result)
        except Exception as e:
            # Top-level UI boundary: surface the failure to the user.
            st.error(e, icon="🚨")


if __name__ == "__main__":
    model = LLM_Langchain()
    model.form_data()