import logging

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("mrm8488/falcoder-7b")
# Move the model to the same device the tokenized inputs are sent to below.
model = AutoModelForCausalLM.from_pretrained("mrm8488/falcoder-7b").to(device)
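
# falcoder-7b is a 7B-parameter checkpoint, so loading it in full precision is
# memory-hungry. A hedged alternative (an assumption, not part of the original
# app) is to load the weights in half precision on GPU:
#
#   model = AutoModelForCausalLM.from_pretrained(
#       "mrm8488/falcoder-7b", torch_dtype=torch.float16).to(device)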
|
|
|
def generate_text(prompt, max_length, do_sample, temperature, top_k, top_p):
    """
    Generates a text completion for a prompt with the given decoding parameters.

    :param prompt: Input prompt for text generation.
    :type prompt: str
    :param max_length: Maximum length of the generated text.
    :type max_length: int
    :param do_sample: Whether to use sampling for text generation.
    :type do_sample: bool
    :param temperature: Sampling temperature for text generation.
    :type temperature: float
    :param top_k: Value for top-k sampling.
    :type top_k: int
    :param top_p: Value for top-p sampling.
    :type top_p: float
    :return: Generated text completion.
    :rtype: str
    """
    formatted_prompt = "\n" + prompt
    if ',' not in prompt:
        formatted_prompt += ','

    # Tokenize and move the inputs to the model's device.
    inputs = tokenizer(formatted_prompt, return_tensors='pt')
    inputs = {key: value.to(device) for key, value in inputs.items()}

    out = model.generate(**inputs, max_length=max_length, do_sample=do_sample,
                         temperature=temperature, no_repeat_ngram_size=3,
                         top_k=top_k, top_p=top_p)
    output = tokenizer.decode(out[0], skip_special_tokens=True)
    # Convert literal "\n" escapes in the decoded text into real newlines.
    clean_output = output.replace('\\n', '\n')

    logger.info("Text generated: %s", clean_output)
    return clean_output
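
# Minimal usage sketch for generate_text (illustrative values; the prompt and
# sampling settings are placeholders, not part of the original app):
#
#   completion = generate_text("def fib(n):", max_length=128, do_sample=True,
#                              temperature=0.7, top_k=50, top_p=0.9)
#   print(completion)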
|
|
|
|
|
custom_css = """
.gradio-container {
    background-color: #0D1525;
    color: white;
}
#orange-button {
    background: #F26207 !important;
    color: white;
}
.cm-gutters {
    border: none !important;
}
"""
|
|
|
def post_processing(prompt, completion):
    """
    Formats a generated completion for display.

    :param prompt: Input prompt for text generation.
    :type prompt: str
    :param completion: Generated text completion.
    :type completion: str
    :return: Formatted text completion.
    :rtype: str
    """
    return prompt + completion
|
|
|
def code_generation(prompt, max_new_tokens, temperature=0.2, seed=42, top_p=0.9, top_k=None, use_cache=True, repetition_penalty=1.0):
    """
    Generates a code completion for a prompt with the given decoding parameters.

    :param prompt: Input prompt for code generation.
    :type prompt: str
    :param max_new_tokens: Maximum number of new tokens to generate.
    :type max_new_tokens: int
    :param temperature: Sampling temperature for code generation.
    :type temperature: float
    :param seed: Random seed for code generation.
    :type seed: int
    :param top_p: Value for top-p sampling.
    :type top_p: float
    :param top_k: Value for top-k sampling.
    :type top_k: int
    :param use_cache: Whether to use the key/value cache during generation.
    :type use_cache: bool
    :param repetition_penalty: Value for the repetition penalty.
    :type repetition_penalty: float
    :return: Generated code completion.
    :rtype: str
    """
    # Rough character-level pre-trim that keeps the tail of very long prompts;
    # exact token-level truncation is applied by the tokenizer below.
    MAX_INPUT_TOKENS = 2048
    if len(prompt) > MAX_INPUT_TOKENS:
        prompt = prompt[-MAX_INPUT_TOKENS:]

    x = tokenizer.encode(prompt, return_tensors="pt", max_length=MAX_INPUT_TOKENS, truncation=True).to(device)
    logger.info("Prompt shape: %s", x.shape)

    set_seed(seed)
    y = model.generate(x,
                       max_new_tokens=max_new_tokens,
                       do_sample=True,  # temperature/top_p/top_k only apply when sampling
                       temperature=temperature,
                       # The tokenizer may define no pad token; fall back to EOS.
                       pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
                       eos_token_id=tokenizer.eos_token_id,
                       top_p=top_p,
                       top_k=top_k,
                       use_cache=use_cache,
                       repetition_penalty=repetition_penalty)
    completion = tokenizer.decode(y[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
    # Strip the echoed prompt so only the newly generated code remains.
    completion = completion[len(prompt):]

    return post_processing(prompt, completion)
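
# Minimal usage sketch for code_generation (illustrative values; the prompt
# is a placeholder, not part of the original app):
#
#   code = code_generation("def quicksort(arr):", max_new_tokens=64)
#   print(code)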
|
|
|
description = """
### Falcoder

Falcoder is a Falcon-7B model fine-tuned for code generation. Given a prompt, it produces text or code completions.

### Text Generation

Use the text generation tab to generate text completions for a prompt. You can adjust the maximum length of the generated text, whether to use sampling, the sampling temperature, and the top-k and top-p sampling values.

### Code Generation

Use the code generation tab to generate code completions for a prompt. You can adjust the maximum number of new tokens, the sampling temperature, the random seed, and the top-p sampling value; the remaining parameters (top-k, cache use, repetition penalty) keep the defaults from code_generation.
"""
|
|
|
# gr.Interface wraps one function, so each generator gets its own interface;
# the two are combined into tabs (slider ranges/defaults are illustrative).
text_interface = gr.Interface(
    generate_text,
    [gr.Textbox(label="Prompt"), gr.Slider(16, 512, value=128, step=1, label="Max length"),
     gr.Checkbox(value=True, label="Sample"), gr.Slider(0.1, 2.0, value=0.7, label="Temperature"),
     gr.Slider(0, 100, value=50, step=1, label="Top-k"), gr.Slider(0.0, 1.0, value=0.9, label="Top-p")],
    gr.Textbox(label="Completion"),
    description=description,
)
code_interface = gr.Interface(
    code_generation,
    [gr.Textbox(label="Prompt"), gr.Slider(16, 512, value=128, step=1, label="Max new tokens"),
     gr.Slider(0.1, 2.0, value=0.2, label="Temperature"), gr.Number(value=42, precision=0, label="Seed"),
     gr.Slider(0.0, 1.0, value=0.9, label="Top-p")],
    gr.Textbox(label="Completion"),
    description=description,
)
demo = gr.TabbedInterface(
    [text_interface, code_interface],
    tab_names=["Text Generation", "Code Generation"],
    title="Falcoder",
    css=custom_css,
)

demo.launch()