gufett0 committed on
Commit
f797fbc
·
1 Parent(s): f7aeb1e

added new class

Browse files
Files changed (2) hide show
  1. backend.py +0 -1
  2. interface.py +10 -7
backend.py CHANGED
@@ -38,7 +38,6 @@ Settings.llm = GemmaLLMInterface(model=model, tokenizer=tokenizer)"""
38
  Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
39
  Settings.llm = GemmaLLMInterface(model_id="google/gemma-2-2b-it")
40
 
41
-
42
  ############################---------------------------------
43
 
44
  # Get the parser
 
38
  Settings.embed_model = InstructorEmbedding(model_name="hkunlp/instructor-base")
39
  Settings.llm = GemmaLLMInterface(model_id="google/gemma-2-2b-it")
40
 
 
41
  ############################---------------------------------
42
 
43
  # Get the parser
interface.py CHANGED
@@ -8,14 +8,17 @@ from threading import Thread
8
 
9
  # for transformers 2
10
  class GemmaLLMInterface(CustomLLM):
11
- def __init__(self, model_id: str = "google/gemma-2-2b-it", context_window: int = 8192, num_output: int = 2048):
12
- self.model_id = model_id
13
- self.context_window = context_window
14
- self.num_output = num_output
15
-
16
- self.tokenizer = AutoTokenizer.from_pretrained(model_id)
 
 
 
17
  self.model = AutoModelForCausalLM.from_pretrained(
18
- model_id,
19
  device_map="auto",
20
  torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
21
  )
 
8
 
9
  # for transformers 2
10
  class GemmaLLMInterface(CustomLLM):
11
+ model_id: str = Field(default="google/gemma-2-2b-it")
12
+ context_window: int = Field(default=8192)
13
+ num_output: int = Field(default=2048)
14
+ tokenizer: Any = Field(default=None)
15
+ model: Any = Field(default=None)
16
+
17
+ def __init__(self, **data):
18
+ super().__init__(**data)
19
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
20
  self.model = AutoModelForCausalLM.from_pretrained(
21
+ self.model_id,
22
  device_map="auto",
23
  torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
24
  )