filipealmeida's picture
Try to load the model in 8 bits
551ba9b unverified
raw
history blame
1.29 kB
import gradio as gr
from transformers import pipeline
import logging
import re
# Set up logging: INFO level for the root logger (configured by basicConfig)
# and for the "transformers" library logger.
logging.basicConfig(level=logging.INFO)
_transformers_logger = logging.getLogger('transformers')
_transformers_logger.setLevel(logging.INFO)

# Text-generation pipeline for the PII-transform model; generate_text() calls
# it through this module-level name.  The "load_in_8bit" model kwarg asks for
# 8-bit weights (NOTE(review): this presumably needs bitsandbytes and a GPU on
# the host — confirm).
llama = pipeline(
    task="text-generation",
    model="filipealmeida/open-llama-3b-v2-pii-transform",
    model_kwargs=dict(load_in_8bit=True),
)
def generate_text(prompt, example=None, *, max_new_tokens=70):
    """Anonymize *prompt* with the PII-transform model and return its answer.

    The prompt is wrapped in an "### Instruction: / ### Response:" template
    and passed to the module-level ``llama`` text-generation pipeline. The
    first line that follows the ``### Response:`` marker in the generated
    text is returned.

    Args:
        prompt: Text to anonymize.
        example: Unused. Defaults to ``None`` so the function can be called
            with a single argument. ``gr.Interface`` passes one argument per
            input component, and this app configures only one.
        max_new_tokens: Maximum number of tokens to generate. Unlike
            ``max_length``, this does not count the prompt tokens, so long
            inputs still get room for a response.

    Returns:
        The stripped first line of the model's response. Returns ``"ERROR"``
        when the ``### Response:`` marker is not found in the generated text.
    """
    logging.debug("Received prompt: %s", prompt)
    # Built as a plain escaped string, not a triple-quoted literal, so that
    # source indentation can never leak spaces into the template.
    model_input = f"\n### Instruction:\n{prompt}\n### Response:\n"
    logging.info("Input : %s", model_input)

    output = llama(model_input, max_new_tokens=max_new_tokens)
    generated_text = output[0]["generated_text"]
    logging.info("Generated text: %s", generated_text)

    # Capture the remainder of the line after the marker. Without DOTALL, "."
    # stops at a newline. Unlike the former "(.*?)\n" pattern, this still
    # matches when generation ends (EOS or token limit) with no newline.
    match = re.search(r"### Response:\n(.*)", generated_text)
    if match:
        parsed_text = match.group(1).strip()
    else:
        parsed_text = "ERROR"
        logging.warning("No matching section found.")
    logging.info("Parsed text: %s", parsed_text)
    return parsed_text
# Create a Gradio interface: a single prompt textbox (pre-filled with a sample
# containing a name and a phone number) whose value is passed to
# generate_text, and a textbox showing the returned text.
interface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(
            lines=1,
            placeholder="Enter text to anonymize...",
            label="Prompt",
            value="My name is Filipe and my phone number is 555-121-2234. How are you?",
        ),
    ],
    outputs=gr.Textbox(label="Generated text"),
)
# Launch the interface. This is module-level, so it runs as soon as the file
# is executed or imported.
interface.launch()