Zwea Htet committed on
Commit
8f095e3
·
1 Parent(s): a550aaa

update llama custom

Browse files
Files changed (1) hide show
  1. models/llamaCustom.py +2 -22
models/llamaCustom.py CHANGED
@@ -20,7 +20,6 @@ from llama_index import (
20
  StorageContext,
21
  load_index_from_storage,
22
  )
23
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
24
  from llama_index.llms import CompletionResponse, CustomLLM, LLMMetadata
25
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
26
 
@@ -58,27 +57,6 @@ def load_model(model_name: str):
58
  return pipe
59
 
60
 
61
- @st.cache_resource
62
- def load_model(mode_name: str):
63
- # llm_model_name = "bigscience/bloom-560m"
64
- tokenizer = AutoTokenizer.from_pretrained(mode_name)
65
- model = AutoModelForCausalLM.from_pretrained(mode_name, config="T5Config")
66
-
67
- pipe = pipeline(
68
- task="text-generation",
69
- model=model,
70
- tokenizer=tokenizer,
71
- # device=0, # GPU device number
72
- # max_length=512,
73
- do_sample=True,
74
- top_p=0.95,
75
- top_k=50,
76
- temperature=0.7,
77
- )
78
-
79
- return pipe
80
-
81
-
82
  class OurLLM(CustomLLM):
83
  def __init__(self, model_name: str, model_pipeline):
84
  self.model_name = model_name
@@ -104,10 +82,12 @@ class OurLLM(CustomLLM):
104
  def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
105
  raise NotImplementedError()
106
 
 
107
  class LlamaCustom:
108
  def __init__(self, model_name: str) -> None:
109
  self.vector_index = self.initialize_index(model_name=model_name)
110
 
 
111
  def initialize_index(_self, model_name: str):
112
  index_name = model_name.split("/")[-1]
113
 
 
20
  StorageContext,
21
  load_index_from_storage,
22
  )
 
23
  from llama_index.llms import CompletionResponse, CustomLLM, LLMMetadata
24
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
25
 
 
57
  return pipe
58
 
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  class OurLLM(CustomLLM):
61
  def __init__(self, model_name: str, model_pipeline):
62
  self.model_name = model_name
 
82
  def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
83
  raise NotImplementedError()
84
 
85
+
86
  class LlamaCustom:
87
  def __init__(self, model_name: str) -> None:
88
  self.vector_index = self.initialize_index(model_name=model_name)
89
 
90
+ @st.cache_resource
91
  def initialize_index(_self, model_name: str):
92
  index_name = model_name.split("/")[-1]
93