Spaces:
Sleeping
Sleeping
File size: 4,478 Bytes
92d7f1e d28e9af 2ddd1e5 62e947b d28e9af 62e947b d28e9af 62e947b d28e9af 62e947b 92d7f1e d28e9af 92d7f1e 2ddd1e5 d28e9af 92d7f1e d28e9af 2ddd1e5 92d7f1e d28e9af 92d7f1e 8b5c657 d28e9af 2ddd1e5 d28e9af 8b5c657 d28e9af 8b5c657 d28e9af 2ddd1e5 d28e9af 92d7f1e d28e9af 8b5c657 92d7f1e d28e9af cd81b99 37a110b 62e947b 37a110b d28e9af 37a110b 92d7f1e 52e0217 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import difflib
import html
import json
import os

import requests
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Base URL prepended to the "/model_structures/<model_id>.json" requests made
# below. Read from the FIREBASE_URL environment variable; os.getenv returns
# None when that variable is not set.
FIREBASE_URL = os.getenv("FIREBASE_URL")
def fetch_from_firebase(model_id, timeout=30.0):
    """Look up the cached structure for ``model_id`` in the remote store.

    Issues a GET for ``{FIREBASE_URL}/model_structures/{model_id}.json`` and
    returns the decoded JSON body when the server answers with HTTP 200.

    The store is only a cache, so every failure mode is reported as a miss
    (``None``) instead of an exception: ``FIREBASE_URL`` not configured, a
    network error or timeout, a non-200 status, or a body that is not JSON.

    Args:
        model_id: Model identifier; interpolated into the URL path unchanged.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The decoded JSON payload, or ``None`` on any kind of miss/failure.
    """
    if not FIREBASE_URL:
        # Without a base URL the request could only fail; treat as a miss.
        return None

    url = f"{FIREBASE_URL}/model_structures/{model_id}.json"
    try:
        # requests never times out by default; an explicit timeout keeps a
        # stalled connection from hanging the app forever.
        response = requests.get(url, timeout=timeout)
        if response.status_code == 200:
            return response.json()
    except (requests.RequestException, ValueError):
        # Connection problems or an undecodable body: fall back to a miss so
        # the caller can compute the structure itself.
        return None
    return None
def save_to_firebase(model_id, structure, timeout=30.0):
    """Store ``structure`` for ``model_id`` in the remote store.

    Issues a PUT of the JSON-serialised ``structure`` to
    ``{FIREBASE_URL}/model_structures/{model_id}.json``.

    Writing is best-effort: a missing ``FIREBASE_URL``, a network error or a
    timeout is reported as ``False`` rather than raised, so a cache outage
    cannot discard a structure the caller has already computed.

    Args:
        model_id: Model identifier; interpolated into the URL path unchanged.
        structure: JSON-serialisable value to store.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        ``True`` when the server answered with HTTP 200, ``False`` otherwise.
    """
    if not FIREBASE_URL:
        return False

    # Serialise outside the try block: a non-serialisable structure is a
    # programming error and should still surface to the caller.
    payload = json.dumps(structure)
    url = f"{FIREBASE_URL}/model_structures/{model_id}.json"
    try:
        # Explicit timeout: requests would otherwise wait indefinitely.
        response = requests.put(url, data=payload, timeout=timeout)
    except requests.RequestException:
        return False
    return response.status_code == 200
def get_model_structure(model_id) -> list[str]:
    """Return ``"<parameter name>: <shape>"`` lines describing a model.

    The remote cache is consulted first. On a miss the model is loaded with
    ``AutoModelForCausalLM.from_pretrained`` (bfloat16, placed on CPU), one
    line is built per ``state_dict`` entry, and the result is written back to
    the cache before being returned.

    Args:
        model_id: Model identifier passed to ``from_pretrained`` and used as
            the cache key.

    Returns:
        One string per state-dict entry, in state-dict order.
    """
    cached = fetch_from_firebase(model_id)
    # Only a non-empty list honours the declared return type. Any other
    # truthy JSON value (for example an object returned for a partial path)
    # is treated as a miss instead of being handed to the diff code.
    if isinstance(cached, list) and cached:
        return cached

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )
    struct_lines = [
        f"{name}: {tensor.shape}" for name, tensor in model.state_dict().items()
    ]
    # Best-effort cache write; the result is deliberately ignored because the
    # structure is returned regardless of whether it could be stored.
    save_to_firebase(model_id, struct_lines)
    return struct_lines
def compare_structures(struct1_lines: list[str], struct2_lines: list[str]):
    """Compute a line-by-line delta between two structure listings.

    Args:
        struct1_lines: Lines describing the first model.
        struct2_lines: Lines describing the second model.

    Returns:
        The lazy iterator produced by ``difflib.ndiff``; each item carries a
        two-character prefix ("- ", "+ ", "  " or "? ") followed by the line.
    """
    return difflib.ndiff(struct1_lines, struct2_lines)
def display_diff(diff):
    """Turn ndiff-style lines into two aligned HTML columns.

    Each input line is classified by its prefix:

    * ``"- "``: highlighted in the left column, blank row on the right.
    * ``"+ "``: highlighted in the right column, blank row on the left.
    * leading space: copied unchanged into both columns.
    * anything else (such as ``"? "`` hint lines): ignored.

    The text after the two-character prefix is HTML-escaped before being
    embedded, because the result is meant to be rendered as raw HTML and the
    line content comes from outside the program.

    Args:
        diff: Iterable of prefixed diff lines.

    Returns:
        ``(left_html, right_html, diff_found)`` where the two HTML strings
        are the column rows joined with ``<br>`` and ``diff_found`` is
        ``True`` when at least one "- " or "+ " line was seen.
    """
    left_lines = []
    right_lines = []
    diff_found = False

    for line in diff:
        # Escape once up front so no branch can emit raw markup.
        text = html.escape(line[2:])
        if line.startswith("- "):
            left_lines.append(
                f'<span style="background-color: #ffdddd;">{text}</span>'
            )
            right_lines.append("")  # keep both columns row-aligned
            diff_found = True
        elif line.startswith("+ "):
            right_lines.append(
                f'<span style="background-color: #ddffdd;">{text}</span>'
            )
            left_lines.append("")  # keep both columns row-aligned
            diff_found = True
        elif line.startswith(" "):
            left_lines.append(text)
            right_lines.append(text)
        # Any other prefix carries no row content and is skipped.

    left_html = "<br>".join(left_lines)
    right_html = "<br>".join(right_lines)
    return left_html, right_html, diff_found
# Set Streamlit page configuration to wide mode
st.set_page_config(layout="wide")

# Apply custom CSS for wider layout
st.markdown(
    """
<style>
.reportview-container .main .block-container {
max-width: 100%;
padding-left: 10%;
padding-right: 10%;
}
.stMarkdown {
white-space: pre-wrap;
}
</style>
""",
    unsafe_allow_html=True,
)

st.title("Model Structure Comparison Tool")
model_id1 = st.text_input("Enter the first HuggingFace Model ID")
model_id2 = st.text_input("Enter the second HuggingFace Model ID")

# One-shot flag: set by the button callback, consumed by the comparison run.
if "compare_button_clicked" not in st.session_state:
    st.session_state.compare_button_clicked = False


def _request_comparison():
    """Button callback: mark that the next script run should compare."""
    st.session_state.compare_button_clicked = True


if st.session_state.compare_button_clicked:
    # Clear the flag before doing any work so that an exception while loading
    # a model cannot leave the app stuck re-running the comparison.
    st.session_state.compare_button_clicked = False
    with st.spinner('Comparing models and loading tokenizers...'):
        if model_id1 and model_id2:
            struct1 = get_model_structure(model_id1)
            struct2 = get_model_structure(model_id2)
            diff = compare_structures(struct1, struct2)
            left_html, right_html, diff_found = display_diff(diff)

            st.write("### Comparison Result")
            if not diff_found:
                st.success("The model structures are identical.")

            col1, col2 = st.columns([1.5, 1.5])  # Adjust the ratio to make columns wider
            with col1:
                st.write("### Model 1")
                st.markdown(left_html, unsafe_allow_html=True)
            with col2:
                st.write("### Model 2")
                st.markdown(right_html, unsafe_allow_html=True)

            # Tokenizer verification
            try:
                tokenizer1 = AutoTokenizer.from_pretrained(model_id1)
                tokenizer2 = AutoTokenizer.from_pretrained(model_id2)
                st.write(f"**{model_id1} Tokenizer Vocab Size**: {tokenizer1.vocab_size}")
                st.write(f"**{model_id2} Tokenizer Vocab Size**: {tokenizer2.vocab_size}")
            except Exception as e:
                st.error(f"Error loading tokenizers: {e}")
        else:
            st.error("Please enter both model IDs.")
else:
    # The callback runs before the script is re-executed, so the flag is
    # already set when the check above is reached and a single click starts
    # the comparison.
    st.button("Compare Models", on_click=_request_comparison)
|