|
import torch |
|
from transformers import XLMRobertaTokenizer, XLMRobertaForSequenceClassification |
|
import streamlit as st |
|
|
|
|
|
# Checkpoint identifier handed to `from_pretrained` for both the tokenizer
# and the classification model.
model_path = "fine_tuned_xlm_roberta"

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@st.cache_resource
def _load_classifier(path, device_name):
    """Load the tokenizer and model once and reuse them across reruns.

    Streamlit re-executes this script on every widget interaction; without
    caching, the checkpoint would be re-read and moved to the device on each
    button click. `st.cache_resource` keeps a single shared instance keyed by
    the arguments.

    Args:
        path: Checkpoint identifier passed to `from_pretrained`.
        device_name: Device string (e.g. "cpu" or "cuda"). A plain string is
            used instead of a `torch.device` so the cache key is trivially
            hashable.

    Returns:
        A `(tokenizer, model)` tuple; the model is already on the requested
        device and switched to evaluation mode.
    """
    loaded_tokenizer = XLMRobertaTokenizer.from_pretrained(path)
    loaded_model = XLMRobertaForSequenceClassification.from_pretrained(path)
    loaded_model.to(torch.device(device_name))
    # Evaluation mode disables dropout so predictions are deterministic.
    loaded_model.eval()
    return loaded_tokenizer, loaded_model


tokenizer, model = _load_classifier(model_path, str(device))
|
|
|
|
|
def classify_text(text, max_length=128):
    """Classify `text` as Kyrgyz or Non-Kyrgyz.

    The text is tokenized (truncated and padded to `max_length` tokens),
    moved to the module-level `device`, and run through the module-level
    `model` with gradient tracking disabled. The logits are turned into
    probabilities with a softmax over the last dimension.

    Args:
        text: The text to classify. The scalar conversions below assume a
            single input (batch of one).
        max_length: Token length the input is truncated/padded to.

    Returns:
        A `(label, confidence)` tuple: `label` is "Kyrgyz" when the
        predicted class index is 1 and "Non-Kyrgyz" otherwise; `confidence`
        is the softmax probability of the predicted class as a float.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=max_length,
    )
    # Every tensor the tokenizer produced must live on the model's device.
    batch = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        logits = model(**batch).logits
        probs = torch.softmax(logits, dim=-1)
        class_idx = probs.argmax(dim=-1).item()
        score = probs[0, class_idx].item()

    if class_idx == 1:
        verdict = "Kyrgyz"
    else:
        verdict = "Non-Kyrgyz"
    return verdict, score
|
|
|
|
|
# ---- Page header ----
st.title("Kyrgyz Language Classifier")
st.write("This tool identifies whether the given text is Kyrgyz or not.")

# ---- Usage notes shown above the input box ----
st.markdown("""

**Instructions:**



* Please enter a **sentence** for better accuracy.

* **Note:** The word "**Салам**" might be classified as Non-Kyrgyz. This is a known exception.

""")

# ---- Input ----
user_input = st.text_area(
    "Enter text to classify:",
    placeholder="Type your sentence here...",
)

# ---- Action ----
# The button is rendered on every run; it reports True only on the run
# triggered by the click.
classify_clicked = st.button("Classify")

if classify_clicked and not user_input.strip():
    # Whitespace-only input: ask for real text instead of calling the model.
    st.warning("Please enter some text for classification.")
elif classify_clicked:
    # The text is passed through unmodified; stripping is used only for the
    # emptiness check above.
    label, confidence = classify_text(user_input)
    st.write(f"Prediction: **{label}**")
    st.write(f"Confidence: **{confidence:.2%}**")