Not-Grim-Refer committed on
Commit e936a3f
1 Parent(s): c12c1d4

Update app.py

Files changed (1)
  1. app.py +57 -159
app.py CHANGED
@@ -1,162 +1,60 @@
- # Import necessary libraries
- from transformers import AutoTokenizer, AutoModelForCausalLM
- import gradio as gr
- import torch
- import logging
-
- # Set up logging
- logging.basicConfig(level=logging.INFO)
- logger = logging.getLogger(__name__)
-
- # Set device to GPU if available, otherwise CPU
- device = "cuda" if torch.cuda.is_available() else "cpu"
-
- # Load tokenizer and model
- tokenizer = AutoTokenizer.from_pretrained("mrm8488/falcoder-7b")
- model = AutoModelForCausalLM.from_pretrained("mrm8488/falcoder-7b")
-
- def generate_text(prompt, max_length, do_sample, temperature, top_k, top_p):
-     """
-     Generates text completion given a prompt and specified parameters.
-
-     :param prompt: Input prompt for text generation.
-     :type prompt: str
-     :param max_length: Maximum length of generated text.
-     :type max_length: int
-     :param do_sample: Whether to use sampling for text generation.
-     :type do_sample: bool
-     :param temperature: Sampling temperature for text generation.
-     :type temperature: float
-     :param top_k: Value for top-k sampling.
-     :type top_k: int
-     :param top_p: Value for top-p sampling.
-     :type top_p: float
-     :return: Generated text completion.
-     :rtype: str
-     """
-
-     # Format prompt
-     formatted_prompt = "\n" + prompt
-     if not ',' in prompt:
-         formatted_prompt += ','
-
-     # Tokenize prompt and move to device
-     prompt = tokenizer(formatted_prompt, return_tensors='pt')
-     prompt = {key: value.to(device) for key, value in prompt.items()}
-
-     # Generate text completion using model and specified parameters
-     out = model.generate(**prompt, max_length=max_length, do_sample=do_sample, temperature=temperature,
-                          no_repeat_ngram_size=3, top_k=top_k, top_p=top_p)
-     output = tokenizer.decode(out[0])
-     clean_output = output.replace('\n', '\n')
-
-     # Log generated text completion
-     logger.info("Text generated: %s", clean_output)
-
-     return clean_output
-
- # Define Gradio interface
- custom_css = """
- .gradio-container {
-     background-color: #0D1525;
-     color:white
- }
- #orange-button {
-     background: #F26207 !important;
-     color: white;
- }
- .cm-gutters{
-     border: none !important;
- }
- """
-
- def post_processing(prompt, completion):
-     """
-     Formats generated text completion for display.
-
-     :param prompt: Input prompt for text generation.
-     :type prompt: str
-     :param completion: Generated text completion.
-     :type completion: str
-     :return: Formatted text completion.
-     :rtype: str
-     """
-     return prompt + completion
-
- def code_generation(prompt, max_new_tokens, temperature=0.2, seed=42, top_p=0.9, top_k=None, use_cache=True, repetition_penalty=1.0):
-     """
-     Generates code completion given a prompt and specified parameters.
-
-     :param prompt: Input prompt for code generation.
-     :type prompt: str
-     :param max_new_tokens: Maximum number of tokens to generate.
-     :type max_new_tokens: int
-     :param temperature: Sampling temperature for code generation.
-     :type temperature: float
-     :param seed: Random seed for code generation.
-     :type seed: int
-     :param top_p: Value for top-p sampling.
-     :type top_p: float
-     :param top_k: Value for top-k sampling.
-     :type top_k: int
-     :param use_cache: Whether to use cache for code generation.
-     :type use_cache: bool
-     :param repetition_penalty: Value for repetition penalty.
-     :type repetition_penalty: float
-     :return: Generated code completion.
-     :rtype: str
-     """
-
-     # Truncate prompt if too long
-     MAX_INPUT_TOKENS = 2048
-     if len(prompt) > MAX_INPUT_TOKENS:
-         prompt = prompt[-MAX_INPUT_TOKENS:]
-
-     # Tokenize prompt and move to device
-     x = tokenizer.encode(prompt, return_tensors="pt", max_length=MAX_INPUT_TOKENS, truncation=True).to(device)
-     logger.info("Prompt shape: %s", x.shape)
-
-     # Generate code completion using model and specified parameters
-     set_seed(seed)
-     y = model.generate(x,
-                        max_new_tokens=max_new_tokens,
-                        temperature=temperature,
-                        pad_token_id=tokenizer.pad_token_id,
-                        eos_token_id=tokenizer.eos_token_id,
-                        top_p=top_p,
-                        top_k=top_k,
-                        use_cache=use_cache,
-                        repetition_penalty=repetition_penalty
-                        )
-     completion = tokenizer.decode(y[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
-     completion = completion[len(prompt):]
-
-     return post_processing(prompt, completion)
-
- description = """
- ### Falcoder
-
- Falcoder is a GPT-2 model fine-tuned on Python code. It can be used for generating code completions given a prompt.
-
- ### Text Generation
-
- Use the text generation section to generate text completions given a prompt. You can adjust the maximum length of the generated text, whether to use sampling, the sampling temperature, and the top-k and top-p values for sampling.
-
- ### Code Generation
-
- Use the code generation section to generate code completions given a prompt. You can adjust the maximum number of tokens to generate, the sampling temperature, the random seed, the top-p and top-k values for sampling, whether to use cache, and the repetition penalty.
  """

- demo = gr.Interface(
-     [generate_text, code_generation],
-     ["textbox", "textbox"],
-     ["textbox", "textbox"],
-     title="Falcoder",
-     description=description,
-     theme="compact",
-     layout="vertical",
-     css=custom_css
- )
-
- # Launch Gradio interface
- demo.launch()
 
+ import streamlit as st
+
+ st.title("Falcon QA Bot")
+
+ # import chainlit as cl
+
+ import os
+ huggingfacehub_api_token = st.secrets["hf_token"]
+
+ from langchain import HuggingFaceHub, PromptTemplate, LLMChain
+
+ repo_id = "tiiuae/falcon-7b-instruct"
+ llm = HuggingFaceHub(huggingfacehub_api_token=huggingfacehub_api_token,
+                      repo_id=repo_id,
+                      model_kwargs={"temperature":0.2, "max_new_tokens":2000})
+
+ template = """
+ You are an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
+
+ {question}
+
  """
+ # input = st.text_input("What do you want to ask about", placeholder="Input your question here")
+
+
+ # # @cl.langchain_factory
+ # def factory():
+ #     prompt = PromptTemplate(template=template, input_variables=['question'])
+ #     llm_chain = LLMChain(prompt=prompt, llm=llm, verbose=True)
+
+ #     return llm_chain
+
+
+ prompt = PromptTemplate(template=template, input_variables=["question"])
+ llm_chain = LLMChain(prompt=prompt,verbose=True,llm=llm)
+
+ # result = llm_chain.predict(question=input)
+
+ # print(result)
+
+ def chat(query):
+     # prompt = PromptTemplate(template=template, input_variables=["question"])
+     # llm_chain = LLMChain(prompt=prompt,verbose=True,llm=llm)
+
+     result = llm_chain.predict(question=query)
+
+     return result
+
+
+
+
+ def main():
+     input = st.text_input("What do you want to ask about", placeholder="Input your question here")
+     if input:
+         output = chat(input)
+         st.write(output,unsafe_allow_html=True)
+

+ if __name__ == '__main__':
+     main()
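
The new app.py routes every question through a single LangChain LLMChain backed by tiiuae/falcon-7b-instruct on the Hugging Face Inference API; Streamlit only supplies the text box and the output rendering. For testing the chain outside Streamlit, here is a minimal sketch under stated assumptions: it uses the same legacy top-level langchain imports as the commit (HuggingFaceHub, PromptTemplate, LLMChain), and it reads the API token from an HF_TOKEN environment variable instead of st.secrets["hf_token"]; the variable name and the sample question are illustrative, not part of the commit.

# Standalone sketch of the committed chain (assumptions: legacy `langchain`
# package, Hugging Face token exported as HF_TOKEN instead of a Streamlit secret).
import os

from langchain import HuggingFaceHub, LLMChain, PromptTemplate

# Same prompt template and generation settings as the committed app.py.
template = """
You are an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.

{question}

"""

llm = HuggingFaceHub(
    huggingfacehub_api_token=os.environ["HF_TOKEN"],  # assumed env var, not st.secrets
    repo_id="tiiuae/falcon-7b-instruct",
    model_kwargs={"temperature": 0.2, "max_new_tokens": 2000},
)
llm_chain = LLMChain(prompt=PromptTemplate(template=template, input_variables=["question"]), llm=llm)

if __name__ == "__main__":
    # Illustrative question; any string can be passed as `question`.
    print(llm_chain.predict(question="What is Falcon-7B-Instruct?"))

When running the app itself (streamlit run app.py), st.secrets["hf_token"] expects an hf_token entry in the Space's secrets or in a local .streamlit/secrets.toml.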