# Import necessary libraries
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
import gradio as gr
import torch
import logging
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Set device to GPU if available, otherwise CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load tokenizer and model, and move the model to the target device
tokenizer = AutoTokenizer.from_pretrained("mrm8488/falcoder-7b")
model = AutoModelForCausalLM.from_pretrained("mrm8488/falcoder-7b").to(device)
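# Note: a 7B checkpoint is heavy in full precision; on a GPU it is commonly
# loaded in half precision instead (a sketch, assuming sufficient VRAM):
#     model = AutoModelForCausalLM.from_pretrained(
#         "mrm8488/falcoder-7b", torch_dtype=torch.float16).to(device)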
def generate_text(prompt, max_length, do_sample, temperature, top_k, top_p):
"""
Generates text completion given a prompt and specified parameters.
:param prompt: Input prompt for text generation.
:type prompt: str
:param max_length: Maximum length of generated text.
:type max_length: int
:param do_sample: Whether to use sampling for text generation.
:type do_sample: bool
:param temperature: Sampling temperature for text generation.
:type temperature: float
:param top_k: Value for top-k sampling.
:type top_k: int
:param top_p: Value for top-p sampling.
:type top_p: float
:return: Generated text completion.
:rtype: str
"""
# Format prompt
formatted_prompt = "\n" + prompt
    if ',' not in prompt:
        formatted_prompt += ','
    # Tokenize prompt and move the input tensors to the device
    inputs = tokenizer(formatted_prompt, return_tensors='pt')
    inputs = {key: value.to(device) for key, value in inputs.items()}
    # Generate text completion using the model and the specified parameters
    out = model.generate(**inputs, max_length=max_length, do_sample=do_sample,
                         temperature=temperature, no_repeat_ngram_size=3,
                         top_k=top_k, top_p=top_p)
    # Decode, dropping special tokens such as the end-of-sequence marker
    clean_output = tokenizer.decode(out[0], skip_special_tokens=True)
    # Log the generated text completion
    logger.info("Text generated: %s", clean_output)
    return clean_output
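# Illustrative call, with a hypothetical prompt (not executed at import time):
#     generate_text("def fibonacci(n):", max_length=64, do_sample=True,
#                   temperature=0.7, top_k=50, top_p=0.9)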
# Define Gradio interface
custom_css = """
.gradio-container {
background-color: #0D1525;
    color: white;
}
#orange-button {
background: #F26207 !important;
color: white;
}
.cm-gutters{
border: none !important;
}
"""
def post_processing(prompt, completion):
"""
Formats generated text completion for display.
:param prompt: Input prompt for text generation.
:type prompt: str
:param completion: Generated text completion.
:type completion: str
:return: Formatted text completion.
:rtype: str
"""
return prompt + completion
def code_generation(prompt, max_new_tokens, temperature=0.2, seed=42, top_p=0.9, top_k=None, use_cache=True, repetition_penalty=1.0):
"""
Generates code completion given a prompt and specified parameters.
:param prompt: Input prompt for code generation.
:type prompt: str
:param max_new_tokens: Maximum number of tokens to generate.
:type max_new_tokens: int
:param temperature: Sampling temperature for code generation.
:type temperature: float
:param seed: Random seed for code generation.
:type seed: int
:param top_p: Value for top-p sampling.
:type top_p: float
:param top_k: Value for top-k sampling.
:type top_k: int
:param use_cache: Whether to use cache for code generation.
:type use_cache: bool
:param repetition_penalty: Value for repetition penalty.
:type repetition_penalty: float
:return: Generated code completion.
:rtype: str
"""
    # Tokenize the prompt; if it exceeds the context budget, keep only the
    # most recent tokens, since the text nearest the completion point matters most
    MAX_INPUT_TOKENS = 2048
    x = tokenizer.encode(prompt, return_tensors="pt")
    if x.shape[1] > MAX_INPUT_TOKENS:
        x = x[:, -MAX_INPUT_TOKENS:]
        prompt = tokenizer.decode(x[0], skip_special_tokens=True)
    x = x.to(device)
logger.info("Prompt shape: %s", x.shape)
# Generate code completion using model and specified parameters
set_seed(seed)
    y = model.generate(x,
                       max_new_tokens=max_new_tokens,
                       temperature=temperature,
                       # Fall back to EOS when the tokenizer defines no pad token
                       pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
                       eos_token_id=tokenizer.eos_token_id,
                       top_p=top_p,
                       top_k=top_k,
                       use_cache=use_cache,
                       repetition_penalty=repetition_penalty
                       )
    # Decode only the newly generated tokens so the prompt is not duplicated
    completion = tokenizer.decode(y[0][x.shape[1]:], skip_special_tokens=True,
                                  clean_up_tokenization_spaces=False)
    return post_processing(prompt, completion)
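# Illustrative call, with a hypothetical prompt and default sampling settings:
#     code_generation("def quicksort(arr):", max_new_tokens=128)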
description = """
### Falcoder
Falcoder is a Falcon-7B model fine-tuned on Python code. It can be used for generating code completions given a prompt.
### Text Generation
Use the text generation section to generate text completions given a prompt. You can adjust the maximum length of the generated text, whether to use sampling, the sampling temperature, and the top-k and top-p values for sampling.
### Code Generation
Use the code generation section to generate code completions given a prompt. You can adjust the maximum number of tokens to generate, the sampling temperature, the random seed, the top-p and top-k values for sampling, whether to use cache, and the repetition penalty.
"""
# gr.Interface accepts a single function, so each generator gets its own
# interface and the two are combined into tabs (input widget ranges below
# are illustrative defaults)
text_tab = gr.Interface(
    generate_text,
    inputs=[gr.Textbox(label="Prompt"), gr.Slider(16, 1024, 128, step=1, label="Max length"),
            gr.Checkbox(True, label="Sample"), gr.Slider(0.1, 2.0, 0.7, label="Temperature"),
            gr.Slider(0, 100, 50, step=1, label="Top-k"), gr.Slider(0.0, 1.0, 0.9, label="Top-p")],
    outputs=gr.Textbox(label="Completion"),
    description=description,
)
code_tab = gr.Interface(
    code_generation,
    inputs=[gr.Textbox(label="Prompt"), gr.Slider(16, 1024, 128, step=1, label="Max new tokens"),
            gr.Slider(0.1, 2.0, 0.2, label="Temperature"), gr.Number(42, label="Seed", precision=0),
            gr.Slider(0.0, 1.0, 0.9, label="Top-p"), gr.Slider(0, 100, 50, step=1, label="Top-k"),
            gr.Checkbox(True, label="Use cache"), gr.Slider(1.0, 2.0, 1.0, label="Repetition penalty")],
    outputs=gr.Textbox(label="Completion"),
    description=description,
)
demo = gr.TabbedInterface([text_tab, code_tab], ["Text Generation", "Code Generation"],
                          title="Falcoder", css=custom_css)
# Launch Gradio interface
demo.launch()
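# On a hosted Space the bare launch() above suffices; running locally, a
# temporary public link can be requested instead (deployment assumption):
#     demo.launch(share=True)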