CogwiseAI committed on
Commit 855620f · 1 Parent(s): 2dbbece

Update app.py

Files changed (1): app.py +50 -236
app.py CHANGED
@@ -1,258 +1,72 @@
- import streamlit as st
- import uuid
- import sys
- import requests
- import pandas as pd
- from peft import *
- import bitsandbytes as bnb
- import pandas as pd
  import torch
- import torch.nn as nn
- import transformers
- from datasets import load_dataset
- from huggingface_hub import notebook_login
- from peft import (
-     LoraConfig,
-     PeftConfig,
-     get_peft_model,
-     prepare_model_for_kbit_training,
- )
- from transformers import (
-     AutoConfig,
-     AutoModelForCausalLM,
-     AutoTokenizer,
-     BitsAndBytesConfig,
- )
-
- USER_ICON = "images/user-icon.png"
- AI_ICON = "images/ai-icon.png"
- MAX_HISTORY_LENGTH = 5
-
- if 'user_id' in st.session_state:
-     user_id = st.session_state['user_id']
- else:
-     user_id = str(uuid.uuid4())
-     st.session_state['user_id'] = user_id
-
- if 'chat_history' not in st.session_state:
-     st.session_state['chat_history'] = []
-
- if "chats" not in st.session_state:
-     st.session_state.chats = [
-         {
-             'id': 0,
-             'question': '',
-             'answer': ''
-         }
-     ]
-
- if "questions" not in st.session_state:
-     st.session_state.questions = []
-
- if "answers" not in st.session_state:
-     st.session_state.answers = []
-
- if "input" not in st.session_state:
-     st.session_state.input = ""
-
- st.markdown("""
- <style>
- .block-container {
-     padding-top: 32px;
-     padding-bottom: 32px;
-     padding-left: 0;
-     padding-right: 0;
- }
- .element-container img {
-     background-color: #000000;
- }
-
- .main-header {
-     font-size: 24px;
- }
- </style>
- """, unsafe_allow_html=True)
-
- def write_top_bar():
-     col1, col2, col3 = st.columns([1,10,2])
-     with col1:
-         st.image(AI_ICON, use_column_width='always')
-     with col2:
-         header = "Cogwise Intelligent Assistant"
-         st.write(f"<h3 class='main-header'>{header}</h3>", unsafe_allow_html=True)
-     with col3:
-         clear = st.button("Clear Chat")
-     return clear
-
- clear = write_top_bar()
-
- if clear:
-     st.session_state.questions = []
-     st.session_state.answers = []
-     st.session_state.input = ""
-     st.session_state["chat_history"] = []
-
- def handle_input():
-     input = st.session_state.input
-     question_with_id = {
-         'question': input,
-         'id': len(st.session_state.questions)
-     }
-     st.session_state.questions.append(question_with_id)

-     chat_history = st.session_state["chat_history"]
-     if len(chat_history) == MAX_HISTORY_LENGTH:
-         chat_history = chat_history[:-1]
-
-     # api_url = "https://9pl792yjf9.execute-api.us-east-1.amazonaws.com/beta/chatcogwise"
-     # api_request_data = {"question": input, "session": user_id}
-     # api_response = requests.post(api_url, json=api_request_data)
-     # result = api_response.json()
-
-     # answer = result['answer']
-     # !pip install -Uqqq pip --progress-bar off
-     # !pip install -qqq bitsandbytes==0.39.0
-     # !pip install -qqq torch==2.0.1 --progress-bar off
-     # !pip install -qqq -U git+https://github.com/huggingface/transformers.git@e03a9cc --progress-bar off
-     # !pip install -qqq -U git+https://github.com/huggingface/peft.git@42a184f --progress-bar off
-     # !pip install -qqq -U git+https://github.com/huggingface/accelerate.git@c9fbb71 --progress-bar off
-     # !pip install -qqq datasets==2.12.0 --progress-bar off
-     # !pip install -qqq loralib==0.1.1 --progress-bar off
-     # !pip install einops
-
-     import os
-     # from pprint import pprint
-     # import json
-
-     os.environ["CUDA_VISIBLE_DEVICES"] = "0"
-
-     # notebook_login()
-     # 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
-
-     # """### Load dataset"""
-
-     from datasets import load_dataset
-
-     dataset_name = "nisaar/Lawyer_GPT_India"
-     # dataset_name = "patrick11434/TEST_LLM_DATASET"
-     dataset = load_dataset(dataset_name, split="train")
-
-     # """## Load adapters from the Hub
-
-     # You can also directly load adapters from the Hub using the commands below:
-     # """

-     # change peft_model_id
-     bnb_config = BitsAndBytesConfig(
-         load_in_4bit=True,
-         load_4bit_use_double_quant=True,
-         bnb_4bit_quant_type="nf4",
-         bnb_4bit_compute_dtype=torch.bfloat16,
-     )
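
Note: load_4bit_use_double_quant is not a BitsAndBytesConfig argument; the transformers parameter is bnb_4bit_use_double_quant, so the double-quantization flag above was most likely ignored. A corrected version of this config (same NF4 4-bit setup) would be:

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,  # correct parameter name for nested quantization
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)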
-
-     peft_model_id = "nisaar/falcon7b-Indian_Law_150Prompts"
-     config = PeftConfig.from_pretrained(peft_model_id)
-     model = AutoModelForCausalLM.from_pretrained(
-         config.base_model_name_or_path,
-         return_dict=True,
-         quantization_config=bnb_config,
-         device_map="auto",
-         trust_remote_code=True,
      )
-     tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
-     tokenizer.pad_token = tokenizer.eos_token
-
-     model = PeftModel.from_pretrained(model, peft_model_id)
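
Because all of this loading runs inside handle_input, the dataset, base model, and LoRA adapter were re-fetched and re-initialized on every user message. The usual Streamlit pattern is to hoist loading to module level behind a cache. A minimal sketch, assuming Streamlit >= 1.18 for st.cache_resource (the helper name is illustrative):

@st.cache_resource
def load_model_and_tokenizer():
    # Runs once per process; later calls return the cached objects.
    config = PeftConfig.from_pretrained(peft_model_id)
    base_model = AutoModelForCausalLM.from_pretrained(
        config.base_model_name_or_path,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
    tokenizer.pad_token = tokenizer.eos_token
    return PeftModel.from_pretrained(base_model, peft_model_id), tokenizer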

-     """## Inference
-
-     You can then use the trained model, or the model you loaded from the 🤗 Hub, for inference as you usually would in `transformers`.
-     """
-
-     generation_config = model.generation_config
-     generation_config.max_new_tokens = 200
-     generation_config_temperature = 1
-     generation_config.top_p = 0.7
-     generation_config.num_return_sequences = 1
-     generation_config.pad_token_id = tokenizer.eos_token_id
-     generation_config_eod_token_id = tokenizer.eos_token_id
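
Note: two of the assignments above create throwaway local variables instead of setting attributes, so they had no effect on generation: generation_config_temperature and generation_config_eod_token_id (the latter also an "eod"/"eos" typo). The intended lines were presumably:

generation_config.temperature = 1
generation_config.eos_token_id = tokenizer.eos_token_id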

-     DEVICE = "cuda:0"
-
-     # Commented out IPython magic to ensure Python compatibility.
-     # %%time
-     # prompt = f"""
-     # <human>: Who appoints the Chief Justice of India?
-     # <assistant>:
-     # """.strip()
-     #
-     # encoding = tokenizer(prompt, return_tensors="pt").to(DEVICE)
-     # with torch.inference_mode():
-     #     outputs = model.generate(
-     #         input_ids=encoding.attention_mask,
-     #         generation_config=generation_config,
-     #     )
-     # print(tokenizer.decode(outputs[0], skip_special_tokens=True))

-     def generate_response(question: str) -> str:
-         prompt = f"""
- <human>: {question}
- <assistant>:
- """.strip()
-         encoding = tokenizer(prompt, return_tensors="pt").to(DEVICE)
-         with torch.inference_mode():
-             outputs = model.generate(
-                 input_ids=encoding.input_ids,
-                 attention_mask=encoding.attention_mask,
-                 generation_config=generation_config,
-             )
-         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-         assistant_start = '<assistant>:'
-         response_start = response.find(assistant_start)
-         return response[response_start + len(assistant_start):].strip()

-     # prompt = "Debate the merits and demerits of introducing simultaneous elections in India?"
-     prompt = input
-     answer = print(generate_response(prompt))
-
-     # answer = 'Yes'
-     chat_history.append((input, answer))
-
-     st.session_state.answers.append({
-         'answer': answer,
-         'id': len(st.session_state.questions)
-     })
-     st.session_state.input = ""
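
Note: print() returns None, so answer = print(generate_response(prompt)) stores None in the chat history and the rendered answer is always empty. The presumable intent:

answer = generate_response(prompt)
print(answer)  # optional: log the generated answer to the console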

- def write_user_message(md):
-     col1, col2 = st.columns([1,12])
-
-     with col1:
-         st.image(USER_ICON, use_column_width='always')
-     with col2:
-         st.warning(md['question'])
-
- def render_answer(answer):
-     col1, col2 = st.columns([1,12])
-     with col1:
-         st.image(AI_ICON, use_column_width='always')
-     with col2:
-         st.info(answer)
-
- def write_chat_message(md, q):
-     chat = st.container()
-     with chat:
-         render_answer(md['answer'])
-
- with st.container():
-     for (q, a) in zip(st.session_state.questions, st.session_state.answers):
-         write_user_message(q)
-         write_chat_message(a, q)
-
- st.markdown('---')
- input = st.text_input("You are talking to an AI, ask any question.", key="input", on_change=handle_input)
+ import gradio as gr
+ from transformers import AutoTokenizer, AutoModelForCausalLM
  import torch

+ model = AutoModelForCausalLM.from_pretrained(
+     "CogwiseAI/testchatexample",
+     torch_dtype=torch.bfloat16,
+     trust_remote_code=True,
+     device_map="auto",
+     low_cpu_mem_usage=True,
+ )
+ tokenizer = AutoTokenizer.from_pretrained("CogwiseAI/testchatexample")
+
+ def generate_text(input_text):
+     input_ids = tokenizer.encode(input_text, return_tensors="pt")
+     attention_mask = torch.ones(input_ids.shape)
+
+     output = model.generate(
+         input_ids,
+         attention_mask=attention_mask,
+         max_length=200,
+         do_sample=True,
+         top_k=10,
+         num_return_sequences=1,
+         eos_token_id=tokenizer.eos_token_id,
      )
+
+     output_text = tokenizer.decode(output[0], skip_special_tokens=True)
+     print(output_text)
+
+     # Remove prompt echo from generated text
+     cleaned_output_text = output_text.replace(input_text, "")
+     return cleaned_output_text
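
Note: torch.ones(input_ids.shape) builds a float attention mask and assumes there is no padding; the tokenizer can produce the proper mask directly. A sketch of the same encode step:

encoding = tokenizer(input_text, return_tensors="pt")
input_ids = encoding.input_ids
attention_mask = encoding.attention_mask  # integer mask, correct even with padding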

+ block = gr.Blocks()
+
+ with block:
+     gr.Markdown("""<h1><center>Cogwise AI Falcon-7B Instruct</center></h1>
+     """)
+     chatbot = gr.Chatbot()
+     message = gr.Textbox(placeholder=prompt)
+     state = gr.State()
+     submit = gr.Button("SEND")
+     submit.click(generate_text, inputs=[message, state], outputs=[chatbot, state])
+
+ block.launch(debug=True)
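
As written, this UI cannot start or run correctly: prompt is never defined, so gr.Textbox(placeholder=prompt) raises a NameError, and generate_text takes one string and returns one string, which does not match inputs=[message, state] / outputs=[chatbot, state]. A minimal working wiring (a sketch; the chat wrapper and placeholder text are illustrative, not from the original):

def chat(message, history):
    # Accumulate (user, bot) pairs in gr.State and mirror them into the Chatbot.
    history = history or []
    response = generate_text(message)
    history.append((message, response))
    return history, history

with gr.Blocks() as block:
    gr.Markdown("<h1><center>Cogwise AI Falcon-7B Instruct</center></h1>")
    chatbot = gr.Chatbot()
    message = gr.Textbox(placeholder="Ask a question")
    state = gr.State()
    submit = gr.Button("SEND")
    submit.click(chat, inputs=[message, state], outputs=[chatbot, state])

block.launch(debug=True)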
 
 
 
+ # logo = (
+ #     "<div >"
+ #     "<img src='ai-icon.png' alt='image One'>"
+ #     + "</div>"
+ # )
+ # text_generation_interface = gr.Interface(
+ #     fn=generate_text,
+ #     inputs=[
+ #         gr.inputs.Textbox(label="Input Text"),
+ #     ],
+ #     outputs=gr.inputs.Textbox(label="Generated Text"),
+ #     title="Falcon-7B Instruct",
+ #     image=logo
+ # ).launch()