waleko commited on
Commit
ff6815d
·
1 Parent(s): 7bb8428

add new model

Browse files
Files changed (1) hide show
  1. webui.py +10 -7
webui.py CHANGED
@@ -25,8 +25,9 @@ from infer import TikzDocument, TikzGenerator
25
 
26
  # assets = files(__package__) / "assets" if __package__ else files("assets") / "."
27
  models = {
28
- "pix2tikz": "pix2tikz/mixed_e362_step201.pth",
29
- "llava-1.5-7b-hf": "waleko/TikZ-llava-1.5-7b"
 
30
  }
31
 
32
 
@@ -35,12 +36,13 @@ def is_quantization(model_name):
35
 
36
 
37
  @lru_cache(maxsize=1)
38
- def cached_load(model_name, **kwargs) -> ImageToTextPipeline:
 
39
  gr.Info("Instantiating model. Could take a while...") # type: ignore
40
  if not is_quantization(model_name):
41
- return pipeline("image-to-text", model=model_name, **kwargs)
42
  else:
43
- model = AutoModelForPreTraining.from_pretrained(model_name, load_in_4bit=True, **kwargs)
44
  processor = AutoProcessor.from_pretrained(model_name)
45
  return pipeline(task="image-to-text", model=model, tokenizer=processor.tokenizer, image_processor=processor.image_processor)
46
 
@@ -76,7 +78,7 @@ def pix2tikz(
76
 
77
 
78
  def inference(
79
- model_name: str,
80
  image_dict: dict,
81
  temperature: float,
82
  top_p: float,
@@ -85,12 +87,13 @@ def inference(
85
  ):
86
  try:
87
  image = image_dict['composite']
 
88
  if "pix2tikz" in model_name:
89
  yield pix2tikz(model_name, image, temperature, top_p, top_k, expand_to_square)
90
  return
91
 
92
  generate = TikzGenerator(
93
- cached_load(model_name, device_map="auto"),
94
  temperature=temperature,
95
  top_p=top_p,
96
  top_k=top_k,
 
25
 
26
  # assets = files(__package__) / "assets" if __package__ else files("assets") / "."
27
  models = {
28
+ "pix2tikz": {"model": "pix2tikz/mixed_e362_step201.pth"},
29
+ "llava-1.5-7b-hf": {"model": "waleko/TikZ-llava-1.5-7b"},
30
+ "new llava-1.5-7b-hf": {"model": "waleko/TikZ-llava-1.5-7b", "revision": "v2"},
31
  }
32
 
33
 
 
36
 
37
 
38
  @lru_cache(maxsize=1)
39
+ def cached_load(model_dict, **kwargs) -> ImageToTextPipeline:
40
+ model_name = model_dict["model"]
41
  gr.Info("Instantiating model. Could take a while...") # type: ignore
42
  if not is_quantization(model_name):
43
+ return pipeline("image-to-text", **model_dict, **kwargs)
44
  else:
45
+ model = AutoModelForPreTraining.from_pretrained(model_name, load_in_4bit=True, revision=model_dict.get("revision", "main"), **kwargs)
46
  processor = AutoProcessor.from_pretrained(model_name)
47
  return pipeline(task="image-to-text", model=model, tokenizer=processor.tokenizer, image_processor=processor.image_processor)
48
 
 
78
 
79
 
80
  def inference(
81
+ model_dict: dict,
82
  image_dict: dict,
83
  temperature: float,
84
  top_p: float,
 
87
  ):
88
  try:
89
  image = image_dict['composite']
90
+ model_name = model_dict["model"]
91
  if "pix2tikz" in model_name:
92
  yield pix2tikz(model_name, image, temperature, top_p, top_k, expand_to_square)
93
  return
94
 
95
  generate = TikzGenerator(
96
+ cached_load(model_dict, device_map="auto"),
97
  temperature=temperature,
98
  top_p=top_p,
99
  top_k=top_k,