kz209 commited on
Commit
01a2ce5
1 Parent(s): 43d8095
Files changed (3) hide show
  1. README.md +1 -1
  2. pages/summarization_example.py +7 -3
  3. utils/model.py +23 -6
README.md CHANGED
@@ -53,4 +53,4 @@ Each member should create a branch with their own name and commit to that branch
53
 
54
  ### Bug Fixes and Questions
55
 
56
- For bug fixes or questions, either open an issue or create a branch prefixed with `bug` in name.
 
53
 
54
  ### Bug Fixes and Questions
55
 
56
+ For bug fixes or questions, either open an issue or create a branch prefixed with `bug` in name.
pages/summarization_example.py CHANGED
@@ -5,10 +5,14 @@ import random
5
  from utils.model import Model
6
  from utils.data import dataset
7
 
8
- __default_model_name__ = "lmsys/vicuna-7b-v1.5"
9
- model = Model(__default_model_name__)
10
  load_dotenv()
11
 
 
 
 
 
 
 
12
  random_label = '馃攢 Random dialogue from dataset'
13
  examples = {
14
  "example 1": """Boston's injury reporting for Kristaps Porzi艈模is has been fairly coy. He missed Game 3, but his coach told reporters just before Game 4 that was technically available, but with a catch.
@@ -65,7 +69,7 @@ summarization: """, label='Input Prompting Template', lines=8, placeholder='Inpu
65
  output = gr.Markdown()
66
 
67
  example_dropdown.change(update_input, inputs=[example_dropdown], outputs=[input_text])
68
- submit_button.click(process_input, inputs=[input_text, model_dropdown, Template_text], outputs=[output])
69
 
70
  return demo
71
 
 
5
  from utils.model import Model
6
  from utils.data import dataset
7
 
 
 
8
  load_dotenv()
9
 
10
+ __model_list__ = [
11
+ "lmsys/vicuna-7b-v1.5",
12
+ "tiiuae/falcon-7b-instruct"
13
+ ]
14
+ model = {model_name: Model(model_name) for model_name in __model_list__}
15
+
16
  random_label = '馃攢 Random dialogue from dataset'
17
  examples = {
18
  "example 1": """Boston's injury reporting for Kristaps Porzi艈模is has been fairly coy. He missed Game 3, but his coach told reporters just before Game 4 that was technically available, but with a catch.
 
69
  output = gr.Markdown()
70
 
71
  example_dropdown.change(update_input, inputs=[example_dropdown], outputs=[input_text])
72
+ submit_button.click(process_input, inputs=[input_text, model[model_dropdown], Template_text], outputs=[output])
73
 
74
  return demo
75
 
utils/model.py CHANGED
@@ -1,20 +1,37 @@
1
- from transformers import AutoTokenizer, AutoModelForCausalLM
2
  import transformers
3
  import torch
4
 
5
-
6
  class Model():
7
- def __init__(self, model="lmsys/vicuna-7b-v1.5") -> None:
8
- self.tokenizer = AutoTokenizer.from_pretrained(model)
 
 
 
9
  self.pipeline = transformers.pipeline(
10
  "text-generation",
11
- model=model,
12
  tokenizer=self.tokenizer,
13
  torch_dtype=torch.bfloat16,
14
  trust_remote_code=True,
15
  device_map="auto",
16
  )
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  def gen(self, content, temp=0.1, max_length=500):
19
  sequences = self.pipeline(
20
  content,
@@ -26,4 +43,4 @@ class Model():
26
  return_full_text=False
27
  )
28
 
29
- return sequences[-1]['generated_text'] #'\n'.join([seq['generated_text'] for seq in sequences])
 
1
+ from transformers import AutoTokenizer
2
  import transformers
3
  import torch
4
 
 
5
  class Model():
6
+ number_of_models = 0
7
+
8
+ def __init__(self, model_name="lmsys/vicuna-7b-v1.5") -> None:
9
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
10
+ self.name = model_name
11
  self.pipeline = transformers.pipeline(
12
  "text-generation",
13
+ model=model_name,
14
  tokenizer=self.tokenizer,
15
  torch_dtype=torch.bfloat16,
16
  trust_remote_code=True,
17
  device_map="auto",
18
  )
19
 
20
+ self.update()
21
+
22
+ @classmethod
23
+ def update(cls):
24
+ cls.number_of_models += 1
25
+
26
+ def return_mode_name(self):
27
+ return self.name
28
+
29
+ def return_tokenizer(self):
30
+ return self.tokenizer
31
+
32
+ def return_model(self):
33
+ return self.pipeline
34
+
35
  def gen(self, content, temp=0.1, max_length=500):
36
  sequences = self.pipeline(
37
  content,
 
43
  return_full_text=False
44
  )
45
 
46
+ return sequences[-1]['generated_text']