yunzi7 commited on
Commit
387a8a0
ยท
1 Parent(s): 73a7687

load model

Browse files
ice_breaking_challenge/__init__.py CHANGED
@@ -2,10 +2,13 @@ import os
2
 
3
  from flask import Flask, session
4
  from flask_session import Session
 
5
 
 
6
 
7
  def create_app(test_config=None):
8
  """Create and configure an instance of the Flask application."""
 
9
  app = Flask(__name__, instance_relative_config=True)
10
  app.config.from_mapping(
11
  # a default secret that should be overridden by instance config
@@ -16,6 +19,9 @@ def create_app(test_config=None):
16
  app.config['SESSION_TYPE'] = 'filesystem'
17
  Session(app)
18
 
 
 
 
19
  if test_config is None:
20
  # load the instance config, if it exists, when not testing
21
  app.config.from_pyfile("config.py", silent=True)
 
2
 
3
  from flask import Flask, session
4
  from flask_session import Session
5
+ from ice_breaking_challenge.models.model_loader import load_model_with_lora
6
 
7
+ model = None
8
 
9
  def create_app(test_config=None):
10
  """Create and configure an instance of the Flask application."""
11
+ global model
12
  app = Flask(__name__, instance_relative_config=True)
13
  app.config.from_mapping(
14
  # a default secret that should be overridden by instance config
 
19
  app.config['SESSION_TYPE'] = 'filesystem'
20
  Session(app)
21
 
22
+ # model loading
23
+ model = load_model_with_lora()
24
+
25
  if test_config is None:
26
  # load the instance config, if it exists, when not testing
27
  app.config.from_pyfile("config.py", silent=True)
ice_breaking_challenge/models/model_loader.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import keras_nlp
2
+
3
+ MODEL_NAME = "gemma2_instruct_2b_en"
4
+ LORA_WEIGHT_PATH = "weights/gemma2_it_2b_icebreaking.lora.h5"
5
+
6
+ def load_model_with_lora(model_name:str = MODEL_NAME, lora_weight_path: str = LORA_WEIGHT_PATH):
7
+ """
8
+ Keras ๊ธฐ๋ฐ˜ ๋ชจ๋ธ ๋กœ๋“œ ๋ฐ LoRA ๊ฐ€์ค‘์น˜ ์ ์šฉ
9
+
10
+ Args:
11
+ model_name (str): ๋กœ๋“œํ•  ๋ชจ๋ธ์˜ ์ด๋ฆ„
12
+ lora_weight_path (str): ์ ์šฉํ•  LoRA ๊ฐ€์ค‘์น˜ ํŒŒ์ผ์˜ ๊ฒฝ๋กœ
13
+
14
+ Returns:
15
+ keras_nlp.models.GemmaCausalLM: ๋กœ๋“œ๋œ ๋ชจ๋ธ
16
+ """
17
+ model = keras_nlp.models.GemmaCausalLM.from_preset(model_name)
18
+
19
+ model.backbone.load_lora_weights(lora_weight_path)
20
+
21
+ return model