namnh113 commited on
Commit
09234a0
·
1 Parent(s): 9f50780

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +149 -0
app.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ from langchain.llms import HuggingFaceHub
4
+ from models import return_sum_models
5
+
6
class LLM_Langchain():
    """Streamlit chat UI for code summarization backed by HuggingFace Hub models.

    Builds the sidebar controls (API key, language, checkpoint, generation
    hyper-parameters) in ``__init__`` and renders the chat loop in
    ``form_data``. Streamlit re-runs the whole script on every interaction,
    so widget creation order below is significant and is preserved.
    """

    def __init__(self):
        # Static page header and notices.
        st.header('🦜 Code summarization')
        st.warning("Warning: input function needs cleaning and may take long to be processed at first time")
        st.info("Reference: [CodeT5](https://arxiv.org/abs/2109.00859), [The Vault](https://arxiv.org/abs/2305.06156), [CodeXGLUE](https://arxiv.org/abs/2102.04664)")
        st.info("About me: namnh113")

        # HuggingFace token; exported to the environment at the end of __init__.
        self.API_KEY = st.sidebar.text_input(
            'API key (not necessary for now)',
            type='password',
            help="Type in your HuggingFace API key to use this app")

        # Source language of the code the user wants summarized.
        language = st.sidebar.selectbox(
            label="Choose language",
            options=["python", "java", "javascript", "php", "ruby", "go", "cpp"],
            help="Choose languages",
        )

        # The model picker is disabled only when no language is selected.
        disable_model_select = language is None

        # Candidate checkpoints: the project model, plus extras per language.
        base_model = return_sum_models(language)
        model_options = [base_model]
        if language == "python":
            model_options += [base_model + "_v2"]
        if language != "cpp":
            # Salesforce publishes CodeT5 checkpoints for every language here except cpp.
            model_options += ["Salesforce/codet5-base-multi-sum", f"Salesforce/codet5-base-codexglue-sum-{language}"]

        self.checkpoint = st.sidebar.selectbox(
            label="Choose model (nam194/... is my model)",
            options=model_options,
            help="Model used to predict",
            disabled=disable_model_select
        )

        # Generation hyper-parameters, exposed as sidebar sliders.
        self.max_new_tokens = st.sidebar.slider(
            label="Token Length",
            min_value=32,
            max_value=1024,
            step=32,
            value=64,
            help="Set the max tokens to get accurate results"
        )

        self.num_beams = st.sidebar.slider(
            label="num beams",
            min_value=1,
            max_value=10,
            step=1,
            value=4,
            help="Set num beam"
        )

        self.top_k = st.sidebar.slider(
            label="top k",
            min_value=1,
            max_value=50,
            step=1,
            value=30,
            help="Set the top_k"
        )

        self.top_p = st.sidebar.slider(
            label="top p",
            min_value=0.1,
            max_value=1.0,
            step=0.05,
            value=0.95,
            help="Set the top_p"
        )

        # Kwargs forwarded verbatim to HuggingFaceHub in generate_response.
        self.model_kwargs = {
            "max_new_tokens": self.max_new_tokens,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "num_beams": self.num_beams
        }

        os.environ['HUGGINGFACEHUB_API_TOKEN'] = self.API_KEY

    def generate_response(self, input_text):
        """Send ``input_text`` to the selected checkpoint and return the model's reply."""
        hub = HuggingFaceHub(
            repo_id=self.checkpoint,
            model_kwargs=self.model_kwargs
        )
        return hub(input_text)

    def form_data(self):
        """Render the chat history, accept new input, and append the model's answer.

        Any exception is surfaced to the user via ``st.error`` rather than
        crashing the Streamlit script.
        """
        # with st.form('my_form'):
        try:
            # Nag (but do not block) when the token does not look like a HF key.
            if not self.API_KEY.startswith('hf_'):
                st.warning('Please enter your API key!', icon='⚠')

            if "messages" not in st.session_state:
                st.session_state.messages = []

            st.write(f"You are using {self.checkpoint} model")

            # Replay the stored conversation on every rerun.
            for msg in st.session_state.messages:
                with st.chat_message(msg.get('role')):
                    st.write(msg.get("content"))

            user_input = st.chat_input(disabled=False)
            if user_input:
                st.session_state.messages.append(
                    {
                        "role": "user",
                        "content": user_input
                    }
                )
                with st.chat_message("user"):
                    st.write(user_input)

                # Typing "clear" wipes the conversation instead of summarizing.
                if user_input.lower() == "clear":
                    del st.session_state.messages
                    return

                summary = self.generate_response(user_input)
                st.session_state.messages.append(
                    {
                        "role": "assistant",
                        "content": summary
                    }
                )
                with st.chat_message('assistant'):
                    st.markdown(summary)

        except Exception as err:
            st.error(err, icon="🚨")
147
+
148
# Build the app and render the chat UI. Streamlit re-executes this script on
# every user interaction, so both statements run on each rerun.
model = LLM_Langchain()
model.form_data()