Zwea Htet committed · Commit 8f095e3
1 Parent(s): a550aaa

update llama custom

Files changed: models/llamaCustom.py (+2 -22)
models/llamaCustom.py CHANGED

@@ -20,7 +20,6 @@ from llama_index import (
     StorageContext,
     load_index_from_storage,
 )
-from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 from llama_index.llms import CompletionResponse, CustomLLM, LLMMetadata
 from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 
@@ -58,27 +57,6 @@ def load_model(model_name: str):
     return pipe
 
 
-@st.cache_resource
-def load_model(mode_name: str):
-    # llm_model_name = "bigscience/bloom-560m"
-    tokenizer = AutoTokenizer.from_pretrained(mode_name)
-    model = AutoModelForCausalLM.from_pretrained(mode_name, config="T5Config")
-
-    pipe = pipeline(
-        task="text-generation",
-        model=model,
-        tokenizer=tokenizer,
-        # device=0, # GPU device number
-        # max_length=512,
-        do_sample=True,
-        top_p=0.95,
-        top_k=50,
-        temperature=0.7,
-    )
-
-    return pipe
-
-
 class OurLLM(CustomLLM):
     def __init__(self, model_name: str, model_pipeline):
         self.model_name = model_name
@@ -104,10 +82,12 @@ class OurLLM(CustomLLM):
     def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
         raise NotImplementedError()
 
+
 class LlamaCustom:
     def __init__(self, model_name: str) -> None:
         self.vector_index = self.initialize_index(model_name=model_name)
 
+    @st.cache_resource
     def initialize_index(_self, model_name: str):
         index_name = model_name.split("/")[-1]
 
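Note on the second hunk: the deleted block was a duplicate definition of the load_model helper that already exists earlier in the file (the retained copy is the one whose `return pipe` appears at new line 57). As a rough, hypothetical sketch of what that retained loader presumably does, assuming it mirrors the deleted duplicate (the default model name comes from the deleted comment; the dubious config="T5Config" argument is dropped here):

import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

@st.cache_resource  # build the pipeline once and reuse it across Streamlit reruns
def load_model(model_name: str = "bigscience/bloom-560m"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    # Sampling settings copied from the deleted duplicate; the text-generation
    # pipeline forwards them to model.generate().
    pipe = pipeline(
        task="text-generation",
        model=model,
        tokenizer=tokenizer,
        do_sample=True,
        top_p=0.95,
        top_k=50,
        temperature=0.7,
    )
    return pipe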
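Note on the third hunk: the new @st.cache_resource decorator caches the vector index built by initialize_index across Streamlit reruns, and the method's first parameter is spelled _self because Streamlit excludes parameters whose names start with an underscore when hashing arguments for the cache key (the instance itself is not a stable, hashable key). A minimal, self-contained illustration of that pattern, with a hypothetical class and a stand-in return value instead of the Space's actual llama_index loading code:

import streamlit as st

class CachedIndexDemo:
    def __init__(self, model_name: str) -> None:
        self.vector_index = self.initialize_index(model_name=model_name)

    @st.cache_resource
    def initialize_index(_self, model_name: str):
        # `_self` starts with "_", so Streamlit skips hashing the instance and
        # computes the cache key from `model_name` only. The body therefore runs
        # once per distinct model_name; later calls reuse the cached result.
        index_name = model_name.split("/")[-1]
        return {"index_name": index_name}  # stand-in for the real vector index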