|
import functools
import logging
import os
import sys

import gradio as gr
import torch
from huggingface_hub import hf_hub_download, snapshot_download
|
|
|
_REPO_ID = "VLabTech/cognitive_net"
_INPUT_SIZE = 5
_OUTPUT_SIZE = 2

logger = logging.getLogger(__name__)


@functools.lru_cache(maxsize=1)
def _load_model():
    """Download the model code and weights and return the model in eval mode.

    Cached so downloading, importing and weight loading happen once per
    process instead of on every request. If loading raises, nothing is
    cached and the next call retries.
    """
    repo_path = snapshot_download(repo_id=_REPO_ID, local_dir="./model_repo")

    # The model class is imported as a top-level module from the downloaded
    # repo, so its directory must be on sys.path. Add it only once so that
    # repeated calls cannot keep growing sys.path.
    module_dir = os.path.abspath(repo_path)
    if module_dir not in sys.path:
        sys.path.append(module_dir)
    from network import DynamicCognitiveNet

    model = DynamicCognitiveNet(input_size=_INPUT_SIZE, output_size=_OUTPUT_SIZE)

    checkpoint_path = hf_hub_download(
        repo_id=_REPO_ID,
        filename="model.pt",
        local_dir="./model_weights",
    )
    # The model is never moved off the CPU here; map_location="cpu" also keeps
    # a checkpoint that was saved from a GPU loadable on CPU-only hosts.
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a repository you trust.
    model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
    model.eval()
    return model


def _parse_values(input_text):
    """Split comma-separated text into floats (``None`` is treated as "").

    Raises ValueError if any field, including an empty one, is not a number.
    """
    return [float(field.strip()) for field in (input_text or "").split(",")]


def predict(input_text: str) -> str:
    """Process the input text and produce a prediction.

    Args:
        input_text: Five comma-separated numbers, e.g. "1.0, 2.0, 3.0, 4.0, 5.0".

    Returns:
        A result string listing the first two model outputs formatted to four
        decimals. For invalid input or a model failure, it returns an error
        message instead of raising, because this is the Gradio callback.
    """
    # Only parsing failures count as input-format errors; a ValueError raised
    # while loading or running the model is reported as a generic error below.
    try:
        values = _parse_values(input_text)
    except ValueError as e:
        return f"Error dalam format input: {e}"

    if len(values) != _INPUT_SIZE:
        return f"Error: Masukkan tepat 5 nilai (dipisahkan koma). Anda memasukkan {len(values)} nilai."

    try:
        model = _load_model()
        input_tensor = torch.tensor(values, dtype=torch.float32)
        with torch.no_grad():
            output = model(input_tensor)

        # Flatten so that both a (2,) result and a batched (1, 2) result work.
        # NOTE(review): DynamicCognitiveNet's exact output shape is not visible
        # here; confirm against network.py.
        flat = torch.as_tensor(output).reshape(-1).tolist()
        if len(flat) < _OUTPUT_SIZE:
            raise RuntimeError(f"model returned {len(flat)} values, expected {_OUTPUT_SIZE}")

        lines = [f"Output {i}: {v:.4f}" for i, v in enumerate(flat[:_OUTPUT_SIZE], start=1)]
        return "Hasil Prediksi:\n" + "\n".join(lines)
    except Exception as e:  # UI boundary: log the traceback, show a message.
        logger.exception("Prediction failed")
        return f"Error: {e}"
|
|
|
|
|
# Gradio UI: one textbox in, one textbox out, wired to predict().
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(
        label="Input Values",
        placeholder="Masukkan 5 nilai numerik (pisahkan dengan koma). Contoh: 1.0, 2.0, 3.0, 4.0, 5.0",
    ),
    outputs=gr.Textbox(label="Hasil Prediksi"),
    title="Cognitive Network Demo",
    # Built from explicit pieces instead of an indented triple-quoted string:
    # leading indentation inside Markdown can render as a code block, and the
    # two halves of the sentence belong in a single paragraph.
    description=(
        "## Cognitive Network Inference Demo\n\n"
        "Model ini menerima 5 input numerik dan menghasilkan 2 output numerik menggunakan "
        "arsitektur Cognitive Network yang terinspirasi dari cara kerja otak biologis."
    ),
    examples=[
        ["1.0, 2.0, 3.0, 4.0, 5.0"],
        ["0.5, -1.0, 2.5, 1.5, -0.5"],
        ["0.1, 0.2, 0.3, 0.4, 0.5"],
    ],
)


if __name__ == "__main__":
    demo.launch()