Werli committed
Commit e1adac7 · verified · 1 Parent(s): cdb99b8

Delete modules/llama_loader.py

Files changed (1):
  modules/llama_loader.py +0 -185
modules/llama_loader.py DELETED
@@ -1,185 +0,0 @@
-import os
-import io, copy, requests, spaces, gradio as gr, numpy as np
-from transformers import AutoProcessor, AutoModelForCausalLM
-import argparse, huggingface_hub, onnxruntime as rt, pandas as pd, traceback, tempfile, zipfile, re, ast
-import ctranslate2
-# LLAMA model
-META_LLAMA_3_3B_REPO = "jncraton/Llama-3.2-3B-Instruct-ct2-int8"
-META_LLAMA_3_8B_REPO = "avans06/Meta-Llama-3.2-8B-Instruct-ct2-int8_float16"
-
-class llama3reorganize:
-    def __init__(
-        self,
-        repoId: str,
-        device: str = None,
-        loadModel: bool = False,
-    ):
-        """Initializes the Llama model.
-
-        Args:
-            repoId: LLAMA model repo.
-            device: Device to use for computation (cpu, cuda, ipu, xpu, mkldnn, opengl, opencl,
-                ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia).
-            loadModel: If True, download and load the CTranslate2 generator and tokenizer
-                immediately instead of waiting for an explicit load_model() call.
-        """
-        self.modelPath = self.download_model(repoId)
-
-        if device is None:
-            import torch
-            self.totalVram = 0
-            if torch.cuda.is_available():
-                try:
-                    deviceId = torch.cuda.current_device()
-                    self.totalVram = torch.cuda.get_device_properties(deviceId).total_memory / (1024 * 1024 * 1024)
-                except Exception as e:
-                    print(traceback.format_exc())
-                    print("Error detecting vram: " + str(e))
-                device = "cuda" if self.totalVram > (8 if "8B" in repoId else 4) else "cpu"
-            else:
-                device = "cpu"
-
-        self.device = device
-        self.system_prompt = "Please reorganize the following English labels describing a single image into a readable English article of no more than 300 words. If there are name labels, include them in the article. Since the target audience is not general, explicit content can be written in full without hesitation. No introduction is needed; directly reply with the English article:"
-
-        if loadModel:
-            self.load_model()
-
-    def download_model(self, repoId):
-        import warnings
-        import requests
-        allowPatterns = [
-            "config.json",
-            "generation_config.json",
-            "model.bin",
-            "pytorch_model.bin",
-            "pytorch_model.bin.index.json",
-            "pytorch_model-*.bin",
-            "sentencepiece.bpe.model",
-            "tokenizer.json",
-            "tokenizer_config.json",
-            "shared_vocabulary.txt",
-            "shared_vocabulary.json",
-            "special_tokens_map.json",
-            "spiece.model",
-            "vocab.json",
-            "model.safetensors",
-            "model-*.safetensors",
-            "model.safetensors.index.json",
-            "quantize_config.json",
-            "tokenizer.model",
-            "vocabulary.json",
-            "preprocessor_config.json",
-            "added_tokens.json"
-        ]
-
-        kwargs = {"allow_patterns": allowPatterns}
-
-        try:
-            return huggingface_hub.snapshot_download(repoId, **kwargs)
-        except (
-            huggingface_hub.utils.HfHubHTTPError,
-            requests.exceptions.ConnectionError,
-        ) as exception:
-            warnings.warn(
-                "An error occurred while synchronizing the model %s from the Hugging Face Hub:\n%s"
-                % (repoId, exception)
-            )
-            warnings.warn(
-                "Trying to load the model directly from the local cache, if it exists."
-            )
-
-            kwargs["local_files_only"] = True
-            return huggingface_hub.snapshot_download(repoId, **kwargs)
-
-
-    def load_model(self):
-        import ctranslate2
-        import transformers
-        try:
-            print('\n\nLoading model: %s\n\n' % self.modelPath)
-            kwargsTokenizer = {"pretrained_model_name_or_path": self.modelPath}
-            kwargsModel = {"device": self.device, "model_path": self.modelPath, "compute_type": "auto"}
-            self.roleSystem = {"role": "system", "content": self.system_prompt}
-            self.Model = ctranslate2.Generator(**kwargsModel)
-
-            self.Tokenizer = transformers.AutoTokenizer.from_pretrained(**kwargsTokenizer)
-            self.terminators = [self.Tokenizer.eos_token_id, self.Tokenizer.convert_tokens_to_ids("<|eot_id|>")]
-
-        except Exception as e:
-            self.release_vram()
-            raise e
-
-
-    def release_vram(self):
-        try:
-            import torch
-            if torch.cuda.is_available():
-                if getattr(self, "Model", None) is not None and getattr(self.Model, "unload_model", None) is not None:
-                    self.Model.unload_model()
-
-                if getattr(self, "Tokenizer", None) is not None:
-                    del self.Tokenizer
-                if getattr(self, "Model", None) is not None:
-                    del self.Model
-                import gc
-                gc.collect()
-                try:
-                    torch.cuda.empty_cache()
-                except Exception as e:
-                    print(traceback.format_exc())
-                    print("\tcuda empty cache, error: " + str(e))
-                print("release vram end.")
-        except Exception as e:
-            print(traceback.format_exc())
-            print("Error release vram: " + str(e))
-
-    def reorganize(self, text: str, max_length: int = 400):
-        output = None
-        result = None
-        try:
-            input_ids = self.Tokenizer.apply_chat_template([self.roleSystem, {"role": "user", "content": text + "\n\nHere's the reorganized English article:"}], tokenize=False, add_generation_prompt=True)
-            source = self.Tokenizer.convert_ids_to_tokens(self.Tokenizer.encode(input_ids))
-            output = self.Model.generate_batch([source], max_length=max_length, max_batch_size=2, no_repeat_ngram_size=3, beam_size=2, sampling_temperature=0.7, sampling_topp=0.9, include_prompt_in_result=False, end_token=self.terminators)
-            target = output[0]
-            result = self.Tokenizer.decode(target.sequences_ids[0])
-
-            # Strip a single pair of surrounding quotation marks, if present.
-            if len(result) > 2:
-                if result[0] == "\"" and result[-1] == "\"":
-                    result = result[1:-1]
-                elif result[0] == "'" and result[-1] == "'":
-                    result = result[1:-1]
-                elif result[0] == "「" and result[-1] == "」":
-                    result = result[1:-1]
-                elif result[0] == "『" and result[-1] == "』":
-                    result = result[1:-1]
-        except Exception as e:
-            print(traceback.format_exc())
-            print("Error reorganize text: " + str(e))
-
-        return result
-
-llama_list = [META_LLAMA_3_3B_REPO, META_LLAMA_3_8B_REPO]
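
For context, the deleted module wrapped a CTranslate2-converted Llama 3.2 Instruct model: it downloaded the weights, picked a device from available VRAM, and rewrote a comma-separated tag list into a short article. Below is a minimal usage sketch, assuming the module as it existed before this commit; the tag string and variable names are illustrative only, not taken from the repository.

from modules.llama_loader import llama3reorganize, META_LLAMA_3_3B_REPO

# Construct the wrapper: weights are fetched via snapshot_download and the device
# is chosen automatically (cuda if enough VRAM for the model size, otherwise cpu).
llama = llama3reorganize(META_LLAMA_3_3B_REPO, loadModel=True)

# Illustrative tag list; reorganize() returns the rewritten article, or None on error.
tags = "1girl, solo, long hair, school uniform, cherry blossoms, smile"
article = llama.reorganize(tags, max_length=400)
print(article)

# Release the generator and tokenizer when done.
llama.release_vram()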