Spaces:
Sleeping
Sleeping
import streamlit as st | |
import torch | |
from transformers import AutoModelForCausalLM | |
import difflib | |
import requests | |
import os | |
import json | |
FIREBASE_URL = os.getenv("FIREBASE_URL") | |
def fetch_from_firebase(model_id, timeout=10.0):
    """Return the cached structure for ``model_id``, or ``None`` if unavailable.

    Performs a GET on ``{FIREBASE_URL}/model_structures/{model_id}.json`` and
    returns the decoded JSON body when the server answers with HTTP 200.

    The cache is treated as best-effort: a missing ``FIREBASE_URL``, a
    network failure, a timeout, a non-200 status or an undecodable body all
    yield ``None`` so the caller can fall back to computing the structure.

    Args:
        model_id: Identifier used as the key under ``model_structures``.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The decoded JSON payload, or ``None`` when nothing usable was fetched.
    """
    if not FIREBASE_URL:
        # Without a configured endpoint the URL would start with "None/..."
        # and requests would raise; treat it as a cache miss instead.
        return None
    try:
        response = requests.get(
            f"{FIREBASE_URL}/model_structures/{model_id}.json",
            timeout=timeout,
        )
        if response.status_code == 200:
            return response.json()
    except (requests.RequestException, ValueError):
        # ValueError covers a body that is not valid JSON.
        return None
    return None
def save_to_firebase(model_id, structure, timeout=10.0):
    """Store ``structure`` under ``model_id`` in the remote cache.

    Performs a PUT of the JSON-encoded ``structure`` to
    ``{FIREBASE_URL}/model_structures/{model_id}.json``.

    Failures are reported through the return value rather than raised, so a
    cache outage cannot discard a structure the caller has already computed.

    Args:
        model_id: Identifier used as the key under ``model_structures``.
        structure: JSON-serialisable payload to store.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        ``True`` if the server answered with HTTP 200, ``False`` otherwise
        (including when ``FIREBASE_URL`` is unset or the request failed).
    """
    if not FIREBASE_URL:
        return False
    # Serialise outside the try block: a non-serialisable payload is a
    # programming error and should surface, unlike a network failure.
    payload = json.dumps(structure)
    try:
        response = requests.put(
            f"{FIREBASE_URL}/model_structures/{model_id}.json",
            data=payload,
            timeout=timeout,
        )
    except requests.RequestException:
        return False
    return response.status_code == 200
def get_model_structure(model_id) -> list[str]:
    """Return the ``"<parameter name>: <shape>"`` lines describing a model.

    A cached result from ``fetch_from_firebase`` is returned as-is when it is
    truthy. Otherwise the model is loaded on CPU in bfloat16, one line is
    built per ``state_dict`` entry, the lines are handed to
    ``save_to_firebase`` and then returned.
    """
    cached_lines = fetch_from_firebase(model_id)
    if cached_lines:
        return cached_lines

    loaded_model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, device_map="cpu"
    )

    described = []
    for param_name, tensor in loaded_model.state_dict().items():
        shape_text = str(tensor.shape)
        described.append(f"{param_name}: {shape_text}")

    save_to_firebase(model_id, described)
    return described
def compare_structures(struct1_lines: list[str], struct2_lines: list[str]):
    """Return a ``difflib.ndiff`` delta between two lists of structure lines.

    The result is the lazy iterator produced by ``ndiff``; each yielded line
    carries ndiff's two-character prefix (``"- "``, ``"+ "``, ``"  "`` or
    ``"? "``).
    """
    return difflib.ndiff(struct1_lines, struct2_lines)
def display_diff(diff):
    """Turn ndiff-style lines into two aligned HTML columns.

    Lines prefixed with ``"- "`` are highlighted in the left column and leave
    an empty row on the right; lines prefixed with ``"+ "`` do the opposite.
    Lines starting with a space are copied to both columns unchanged. Any
    other line (for example ndiff's ``"? "`` hint lines) is ignored.

    The text of every line is HTML-escaped before being embedded, because the
    result is meant to be rendered as raw HTML and the content originates
    from model checkpoints or a remote cache.

    Args:
        diff: Iterable of strings carrying a two-character ndiff prefix.

    Returns:
        A ``(left_html, right_html, diff_found)`` tuple: the two columns with
        rows joined by ``<br>``, and whether any added or removed line was
        seen.
    """
    import html  # local import keeps this block self-contained

    left_lines = []
    right_lines = []
    diff_found = False

    for line in diff:
        # Escape only markup characters; quotes are harmless in element text.
        text = html.escape(line[2:], quote=False)
        if line.startswith("- "):
            left_lines.append(
                f'<span style="background-color: #ffdddd;">{text}</span>'
            )
            right_lines.append("")
            diff_found = True
        elif line.startswith("+ "):
            right_lines.append(
                f'<span style="background-color: #ddffdd;">{text}</span>'
            )
            left_lines.append("")
            diff_found = True
        elif line.startswith(" "):
            left_lines.append(text)
            right_lines.append(text)

    left_html = "<br>".join(left_lines)
    right_html = "<br>".join(right_lines)
    return left_html, right_html, diff_found
# Set Streamlit page configuration to wide mode
st.set_page_config(layout="wide")

# Apply custom CSS for wider layout. The string is injected as raw HTML, so
# its lines are kept flush-left rather than indented with the Python code.
st.markdown(
    """
<style>
.reportview-container .main .block-container {
max-width: 100%;
padding-left: 10%;
padding-right: 10%;
}
.stMarkdown {
white-space: pre-wrap;
}
</style>
""",
    unsafe_allow_html=True,
)

st.title("Model Structure Comparison Tool")

# Strip surrounding whitespace so a pasted ID with a stray space still works.
model_id1 = st.text_input("Enter the first HuggingFace Model ID").strip()
model_id2 = st.text_input("Enter the second HuggingFace Model ID").strip()

if model_id1 and model_id2:
    try:
        struct1 = get_model_structure(model_id1)
        struct2 = get_model_structure(model_id2)
    except (OSError, ValueError) as exc:
        # Show a readable message instead of a raw traceback when a model
        # cannot be resolved or loaded. NOTE(review): the exception types
        # raised by from_pretrained are assumed here; confirm against the
        # transformers version in use.
        st.error(f"Could not load model structure: {exc}")
    else:
        diff = compare_structures(struct1, struct2)
        left_html, right_html, diff_found = display_diff(diff)

        st.write("### Comparison Result")
        if not diff_found:
            st.success("The model structures are identical.")

        # Two equally wide columns, one per model.
        col1, col2 = st.columns(2)
        with col1:
            st.write("### Model 1")
            st.markdown(left_html, unsafe_allow_html=True)
        with col2:
            st.write("### Model 2")
            st.markdown(right_html, unsafe_allow_html=True)