adarshj322 committed on
Commit
9718b31
·
1 Parent(s): 5b8ba96

Initial Commit

Browse files
Files changed (2) hide show
  1. app.py +58 -0
  2. requirements.txt +0 -0
app.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import wandb
2
+ import torch
3
+ import re
4
+
5
+ import gradio
6
+
7
+ from transformers import GPT2Tokenizer,GPT2LMHeadModel
8
+
9
# Authenticate against Weights & Biases before any run/artifact call.
wandb.login()

# Attach to the run with the fixed id 'xeew4vz7' in project "Question_Answer".
# NOTE(review): resume="must" presumably makes this fail if that run does not
# already exist -- confirm against the wandb.init documentation.
run = wandb.init(project="Question_Answer", job_type="model_loading", id='xeew4vz7', resume="must")

# Declare the model artifact (version v0) as an input of this run.
artifact = run.use_artifact('Question_Answer/final_model_QA:v0')

#artifact = run.use_artifact('enron-subgen-gpt2/model-1hhufzjv:v0')
# Download the artifact to a directory
artifact_dir = artifact.download()

# The tokenizer comes from the stock 'distilgpt2' checkpoint, while the model
# weights below are loaded from the downloaded artifact directory.
MODEL_KEY = 'distilgpt2'
tokenizer= GPT2Tokenizer.from_pretrained(MODEL_KEY)
# Register '{PAD}' as the pad token; generateAnswer tokenizes with padding=True
# and passes tokenizer.pad_token_id to model.generate.
tokenizer.add_special_tokens({'pad_token':'{PAD}'})

model = GPT2LMHeadModel.from_pretrained(artifact_dir)
# Size the embedding table to the tokenizer's current length, which includes
# the pad token added above.
model.resize_token_embeddings(len(tokenizer))
25
+
26
def clean_text(text):
    """Normalize free text before it is embedded in the model prompt.

    The text is lowercased, every non-word character (anything the regex
    class "non-word" matches, i.e. not a letter, digit or underscore) is
    replaced by a space, runs of whitespace are collapsed to a single space
    and leading/trailing whitespace is stripped.

    Digits are kept. The previous version computed a digit-stripped copy of
    the input but never used it; that dead computation was removed without
    changing the returned value.

    Args:
        text: The raw input string.

    Returns:
        The normalized string (possibly empty).
    """
    lowered = text.lower()
    # Punctuation and other symbols become spaces so neighbouring words do
    # not get glued together.
    spaced = re.sub(r'\W', ' ', lowered)
    # Collapse the whitespace runs produced above and trim the ends.
    return re.sub(r'\s+', ' ', spaced).strip()
37
+
38
def generateAnswer(question):
    """Generate an answer for a single question with the loaded model.

    The question is normalized with ``clean_text`` and wrapped as
    ``"<question>" + cleaned + "<answer>"``. The prompt is tokenized as a
    one-element batch, up to 50 new tokens are generated, and the decoded
    text following the first ``<answer>`` marker (up to a second marker, if
    the model emits one) is returned.

    Args:
        question: The raw question text entered by the user.

    Returns:
        The generated answer text. If the decoded sequence does not contain
        the ``<answer>`` marker, the whole decoded sequence is returned
        instead of raising ``IndexError``.
    """
    prompt = "<question>" + clean_text(question) + "<answer>"

    # The tokenizer is called with a list so the tensors carry a batch
    # dimension; they are moved to whichever device the model lives on.
    prompts_batch_ids = tokenizer(
        [prompt], padding=True, truncation=True, return_tensors='pt'
    ).to(model.device)
    output_ids = model.generate(
        **prompts_batch_ids,
        max_new_tokens=50,
        pad_token_id=tokenizer.pad_token_id,
    )

    outputs_batch = []
    for seq in tokenizer.batch_decode(output_ids, skip_special_tokens=True):
        parts = seq.split('<answer>')
        # parts[1] is the text after the first marker; fall back to the raw
        # decoded text when the marker is missing.
        outputs_batch.append(parts[1] if len(parts) > 1 else seq)

    print(outputs_batch)
    return outputs_batch[0]
55
+
56
+
57
# Expose generateAnswer through a Gradio interface with one text input and
# one text output, then start the app.
iface = gradio.Interface(fn=generateAnswer, inputs="text", outputs="text")
iface.launch()
requirements.txt ADDED
File without changes