Zai
test
06db6e9
from g2p.model import Model
from g2p.box_utils import centers_to_extents
from g2p.retrieval import DataRetriever
from g2p.floorplan import FloorPlan
from g2p.align import align_fp_refine
from g2p.add_archs import add_door_window
import numpy as np
import pickle
import torch
class App():
    """Pipeline wrapper around the g2p model: retrieve, transfer, predict, align, decorate.

    An instance always holds a loaded network (``self.model`` on
    ``self.device``).  When ``data_path`` is supplied at construction time it
    also holds a retrieval database (``self.data``) and a ``DataRetriever``
    (``self.retriever``); without them ``retrieve``/``generate`` cannot be used.
    """

    def __init__(self, model_path, device='cpu', data_path=None, tf_path=None,
                 centroid_path=None, cluster_path=None):
        """Load the network and, if ``data_path`` is given, the retrieval database.

        Args:
            model_path: Path to a checkpoint readable by ``torch.load`` that
                holds the state dict for ``Model``.
            device: Device the model and its inputs are moved to.
            data_path: Optional path to the pickled database.  When ``None``
                the database is not loaded.
            tf_path, centroid_path, cluster_path: Paths to the ``.npy`` files
                handed to ``DataRetriever``; all three are required whenever
                ``data_path`` is given.
        """
        super().__init__()
        self.load_model(model_path, device=device)
        if data_path is not None:
            self.load_database(data_path, tf_path, centroid_path, cluster_path)

    def load_model(self, model_path, device='cpu'):
        """Build ``Model``, load its weights from ``model_path`` and put it in eval mode.

        Sets ``self.model`` and ``self.device``.
        """
        model = Model()
        # map_location remaps the stored tensors onto the requested device, so
        # a checkpoint saved on a GPU still loads on a machine without one.
        state_dict = torch.load(model_path, map_location=device)
        model.load_state_dict(state_dict)
        model.to(device)
        model.eval()
        self.model = model
        self.device = device

    def load_database(self, data_path, tf_path, centroid_path, cluster_path):
        """Load the pickled database and build the ``DataRetriever``.

        Sets ``self.data`` (the ``'data'`` entry of the pickle) and
        ``self.retriever``.

        Raises:
            AssertionError: If any of ``tf_path``, ``centroid_path`` or
                ``cluster_path`` is ``None``.
        """
        # NOTE: asserts are kept so callers keep seeing AssertionError; be
        # aware that they are stripped when Python runs with -O.
        assert tf_path is not None, "tf_path is required when data_path is given"
        assert centroid_path is not None, "centroid_path is required when data_path is given"
        assert cluster_path is not None, "cluster_path is required when data_path is given"

        # NOTE(security): pickle can execute arbitrary code while loading;
        # only point data_path at trusted files.
        with open(data_path, 'rb') as handle:
            self.data = pickle.load(handle)['data']

        tf_train = np.load(tf_path)
        centroids = np.load(centroid_path)
        clusters = np.load(cluster_path)
        self.retriever = DataRetriever(tf_train, centroids, clusters)

    def retrieve(self, data_query, k=10, multi_clusters=False):
        """Return the database entries selected by the retriever for ``data_query``.

        ``k`` and ``multi_clusters`` are forwarded unchanged to
        ``DataRetriever.retrieve_cluster``; whatever index it returns is used
        to index ``self.data``.
        """
        index = self.retriever.retrieve_cluster(
            data_query, k=k, multi_clusters=multi_clusters
        )
        return self.data[index]

    def transfer(self, data_boundary, data_graph):
        """Adapt the graph of ``data_graph`` to the boundary of ``data_boundary``.

        Returns the ``data`` object of the adapted floor plan with ``rType``
        and ``rEdge`` refreshed from the adjusted graph.
        """
        fp_boundary = FloorPlan(data_boundary)
        fp_graph = FloorPlan(data_graph)
        fp_transfer = fp_boundary.adapt_graph(fp_graph)
        fp_transfer.adjust_graph()
        fp_transfer.data.rType = fp_transfer.get_rooms(tensor=False)
        # Triple columns are stored in (0, 2, 1) order in rEdge — presumably
        # (subject, object, relation) vs. (subject, relation, object); TODO confirm.
        fp_transfer.data.rEdge = fp_transfer.get_triples(tensor=False)[:, [0, 2, 1]]
        return fp_transfer.data

    def forward(self, data, network_data=False):
        """Run the network on ``data`` and attach its predictions.

        Args:
            data: Floor-plan data object accepted by ``FloorPlan``.
            network_data: When true, ``data.box`` and ``data.edge`` are first
                rebuilt from ``data.gtBoxNew``/``data.rType`` and ``data.rEdge``.

        Returns:
            ``fp.data`` with ``predBox``, ``refineBox`` and ``gene`` set to
            integer NumPy arrays.
        """
        if network_data:
            # Append the room type as an extra last column of each box row.
            data.box = np.concatenate(
                [data.gtBoxNew, data.rType.reshape(-1, 1)], axis=-1
            )
            data.edge = data.rEdge

        fp = FloorPlan(data)
        boundary, inside_box, rooms, attrs, triples = fp.get_test_data()

        # Only the boundary gets an explicit leading (batch) dimension.
        boundary = boundary.unsqueeze(0).to(self.device)
        inside_box = inside_box.to(self.device)
        rooms = rooms.to(self.device)
        attrs = attrs.to(self.device)
        triples = triples.to(self.device)

        with torch.no_grad():
            model_out = self.model(
                rooms,
                triples,
                boundary,
                obj_to_img=None,
                attributes=attrs,
                boxes_gt=None,
                generate=True,
                refine=True,
                relative=True,
                inside_box=inside_box,
            )
        boxes_pred, gene_layout, boxes_refine = model_out

        # Scale by 255 — presumably normalised coordinates to pixel
        # coordinates; TODO confirm against centers_to_extents.
        boxes_pred = centers_to_extents(boxes_pred) * 255
        boxes_pred = boxes_pred.squeeze().cpu().numpy().astype(int)
        boxes_refine = centers_to_extents(boxes_refine) * 255
        boxes_refine = boxes_refine.squeeze().cpu().numpy().astype(int)

        # Per-pixel label: arg-max over dim 1 of the generated layout.
        gene_preds = torch.argmax(gene_layout.softmax(1).detach(), dim=1).squeeze()
        # Pixels where the first boundary channel is 0 are forced to label 13
        # — presumably the "outside the boundary" class; TODO confirm.
        gene_preds[boundary[0, 0] == 0] = 13
        gene_preds = gene_preds.cpu().numpy().astype(int)

        fp.data.predBox = boxes_pred
        fp.data.refineBox = boxes_refine
        fp.data.gene = gene_preds
        return fp.data

    def align(self, data):
        """Run ``align_fp_refine`` on ``data`` and store its three results.

        Sets ``data.newBox``, ``data.order`` and ``data.rBoundary`` and
        returns the same ``data`` object.
        """
        boxes_aligned, order, room_boundaries = align_fp_refine(
            data.boundary,
            data.refineBox,
            data.rType,
            data.rEdge,
            data.gene,
        )
        data.newBox = boxes_aligned
        data.order = order
        data.rBoundary = room_boundaries
        return data

    def decorate(self, data):
        """Attach the doors and windows computed by ``add_door_window`` to ``data``."""
        doors, windows = add_door_window(data)
        data.doors = doors
        data.windows = windows
        return data

    def generate(self, data_boundary):
        """Run the whole pipeline for one boundary and return the resulting data.

        Steps: retrieve candidates and take the first one, transfer its graph
        onto ``data_boundary``, run the network, align the boxes, then add
        doors and windows.
        """
        data = self.retrieve(data_boundary)[0]
        data = self.transfer(data_boundary, data)
        data = self.forward(data)
        data = self.align(data)
        data = self.decorate(data)
        return data