Nekochu committed on
Commit 04b7ee8 · verified · 1 Parent(s): b0cf25c

Update load model

Files changed (1)
  1. app.py +11 -6
app.py CHANGED
@@ -13,29 +13,31 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
 DESCRIPTION = """\
 # Nekochu/Luminia-13B-v3
-
-This Space demonstrates model [Nekochu/Luminia-13B-v3](https://huggingface.co/Nekochu/Luminia-13B-v3) by Nekochu, a Llama 2 model with 13B parameters fine-tuned for SD gen prompt
+This Space demonstrates model Nekochu/Luminia-13B-v3 by Nekochu, a Llama 2 model with 13B parameters fine-tuned for SD gen prompt
 """
 
 LICENSE = """
 <p/>
-
 ---.
 """
 
 if not torch.cuda.is_available():
     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
 
+MODELS = {}
 
-if torch.cuda.is_available():
-    model_id = "Nekochu/Luminia-13B-v3"
+def load_model(model_id):
+    if model_id in MODELS:
+        return MODELS[model_id]
     model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True)
     tokenizer = AutoTokenizer.from_pretrained(model_id)
     tokenizer.use_default_system_prompt = False
-
+    MODELS[model_id] = (model, tokenizer)
+    return model, tokenizer
 
 @spaces.GPU(duration=120)
 def generate(
+    model_id: str,
     message: str,
     chat_history: list[tuple[str, str]],
     system_prompt: str,
@@ -45,6 +47,7 @@ def generate(
     top_k: int = 50,
     repetition_penalty: float = 1.2,
 ) -> Iterator[str]:
+    model, tokenizer = load_model(model_id)  # Load or retrieve the selected model
     conversation = []
     if system_prompt:
         conversation.append({"role": "system", "content": system_prompt})
@@ -78,10 +81,12 @@ def generate(
         outputs.append(text)
         yield "".join(outputs)
 
+MODEL_IDS = ["Nekochu/Luminia-13B-v3", "Nekochu/Llama-2-13B-German-ORPO"]  # Add more model ids as needed
 
 chat_interface = gr.ChatInterface(
     fn=generate,
     additional_inputs=[
+        gr.Dropdown(MODEL_IDS, label="Model ID"),  # Add this line
         gr.Textbox(label="System prompt", lines=6),
         gr.Slider(
             label="Max new tokens",