Spaces:
Runtime error
Runtime error
load model
Browse files
ice_breaking_challenge/__init__.py
CHANGED
@@ -2,10 +2,13 @@ import os
|
|
2 |
|
3 |
from flask import Flask, session
|
4 |
from flask_session import Session
|
|
|
5 |
|
|
|
6 |
|
7 |
def create_app(test_config=None):
|
8 |
"""Create and configure an instance of the Flask application."""
|
|
|
9 |
app = Flask(__name__, instance_relative_config=True)
|
10 |
app.config.from_mapping(
|
11 |
# a default secret that should be overridden by instance config
|
@@ -16,6 +19,9 @@ def create_app(test_config=None):
|
|
16 |
app.config['SESSION_TYPE'] = 'filesystem'
|
17 |
Session(app)
|
18 |
|
|
|
|
|
|
|
19 |
if test_config is None:
|
20 |
# load the instance config, if it exists, when not testing
|
21 |
app.config.from_pyfile("config.py", silent=True)
|
|
|
2 |
|
3 |
from flask import Flask, session
|
4 |
from flask_session import Session
|
5 |
+
from ice_breaking_challenge.models.model_loader import load_model_with_lora
|
6 |
|
7 |
+
model = None
|
8 |
|
9 |
def create_app(test_config=None):
|
10 |
"""Create and configure an instance of the Flask application."""
|
11 |
+
global model
|
12 |
app = Flask(__name__, instance_relative_config=True)
|
13 |
app.config.from_mapping(
|
14 |
# a default secret that should be overridden by instance config
|
|
|
19 |
app.config['SESSION_TYPE'] = 'filesystem'
|
20 |
Session(app)
|
21 |
|
22 |
+
# model loading
|
23 |
+
model = load_model_with_lora()
|
24 |
+
|
25 |
if test_config is None:
|
26 |
# load the instance config, if it exists, when not testing
|
27 |
app.config.from_pyfile("config.py", silent=True)
|
ice_breaking_challenge/models/model_loader.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import keras_nlp
|
2 |
+
|
3 |
+
MODEL_NAME = "gemma2_instruct_2b_en"
|
4 |
+
LORA_WEIGHT_PATH = "weights/gemma2_it_2b_icebreaking.lora.h5"
|
5 |
+
|
6 |
+
def load_model_with_lora(model_name:str = MODEL_NAME, lora_weight_path: str = LORA_WEIGHT_PATH):
|
7 |
+
"""
|
8 |
+
Keras ๊ธฐ๋ฐ ๋ชจ๋ธ ๋ก๋ ๋ฐ LoRA ๊ฐ์ค์น ์ ์ฉ
|
9 |
+
|
10 |
+
Args:
|
11 |
+
model_name (str): ๋ก๋ํ ๋ชจ๋ธ์ ์ด๋ฆ
|
12 |
+
lora_weight_path (str): ์ ์ฉํ LoRA ๊ฐ์ค์น ํ์ผ์ ๊ฒฝ๋ก
|
13 |
+
|
14 |
+
Returns:
|
15 |
+
keras_nlp.models.GemmaCausalLM: ๋ก๋๋ ๋ชจ๋ธ
|
16 |
+
"""
|
17 |
+
model = keras_nlp.models.GemmaCausalLM.from_preset(model_name)
|
18 |
+
|
19 |
+
model.backbone.load_lora_weights(lora_weight_path)
|
20 |
+
|
21 |
+
return model
|