Spaces:
Sleeping
Sleeping
import wandb | |
import torch | |
import re | |
import os | |
import gradio | |
from transformers import GPT2Tokenizer,GPT2LMHeadModel | |
# Credentials: the W&B API key is intentionally NOT set in source. wandb.login()
# picks it up from the WANDB_API_KEY environment variable (configure it as a
# secret in the hosting environment) or from previously stored credentials.
# NOTE(review): a key was previously committed in this file; it should be
# revoked and replaced.
wandb.login()

# Re-attach to the existing W&B run; resume="must" makes init fail instead of
# silently creating a new run if that run id does not exist.
run = wandb.init(
    project="Question_Answer",
    job_type="model_loading",
    id='xeew4vz7',
    resume="must",
)

# Fetch the fine-tuned model weights logged as a W&B artifact and download
# them to a local directory.
artifact = run.use_artifact('Question_Answer/final_model_QA:v0')
artifact_dir = artifact.download()

# The tokenizer comes from the base checkpoint; a pad token is added because
# generateAnswer tokenizes with padding=True and passes pad_token_id to generate.
MODEL_KEY = 'distilgpt2'
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_KEY)
tokenizer.add_special_tokens({'pad_token': '{PAD}'})

# Load the fine-tuned weights and make the embedding matrix match the
# tokenizer size after the pad token was added.
model = GPT2LMHeadModel.from_pretrained(artifact_dir)
model.resize_token_embeddings(len(tokenizer))
def clean_text(text: str) -> str:
    """Normalize text before it is placed into the model prompt.

    The text is lowercased, every character that is not alphanumeric or an
    underscore is replaced by a space, and runs of whitespace are collapsed
    to a single space with leading/trailing whitespace removed.

    Digits are kept. The previous version computed a digit-stripped copy
    into an unused variable and discarded it; that dead statement has been
    removed, so the returned value is unchanged.

    Args:
        text: Raw input text.

    Returns:
        The normalized text.
    """
    lowered = text.lower()
    # Punctuation and other non-word characters become spaces so that
    # adjacent words are not glued together.
    spaced = re.sub(r'\W', ' ', lowered)
    # Collapse whitespace runs introduced above (or present in the input).
    return re.sub(r'\s+', ' ', spaced).strip()
def generateAnswer(question):
    """Generate an answer to a question with the loaded language model.

    The question is normalized with ``clean_text``, wrapped as
    ``<question>...<answer>``, tokenized, and passed to ``model.generate``
    with at most 50 new tokens. The decoded output is then split on the
    ``<answer>`` marker to recover the generated answer.

    Args:
        question: Raw question text.

    Returns:
        The decoded text following the first ``<answer>`` marker (up to the
        next ``<answer>`` marker if the output contains another one), or an
        empty string if the marker is missing from the decoded output.
    """
    answer_marker = "<answer>"
    prompt = "<question>" + clean_text(question) + answer_marker

    # A one-element batch is tokenized and moved to the model's device.
    encoded = tokenizer(
        [prompt], padding=True, truncation=True, return_tensors='pt'
    ).to(model.device)
    output_ids = model.generate(
        **encoded,
        max_new_tokens=50,
        pad_token_id=tokenizer.pad_token_id,
    )
    decoded_batch = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

    outputs_batch = []
    for seq in decoded_batch:
        parts = seq.split(answer_marker)
        # If the marker is absent from the decoded text (e.g. the prompt
        # was truncated), return an empty answer instead of raising
        # IndexError on parts[1].
        outputs_batch.append(parts[1] if len(parts) > 1 else "")

    print(outputs_batch)
    return outputs_batch[0]
# Expose generateAnswer through a text-in / text-out Gradio interface and
# start serving it.
iface = gradio.Interface(
    fn=generateAnswer,
    inputs="text",
    outputs="text",
)
iface.launch()