max-long commited on
Commit
1dc581a
·
verified ·
1 Parent(s): b0ad249

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -0
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
3
+ from datasets import load_dataset
4
+ import gradio as gr
5
+
6
+ # Load the dataset with streaming
7
+ dataset = load_dataset("TheBritishLibrary/blbooks", split="train", streaming=True)
8
+
9
+ # Convert streaming dataset to an iterable
10
+ dataset_iter = iter(dataset)
11
+
12
+ # Load tokenizer and model
13
+ model_name = "max-long/textile_machines_3_oct" # Replace with your model's name
14
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
15
+ model = AutoModelForTokenClassification.from_pretrained(model_name)
16
+
17
+ # Initialize NER pipeline
18
+ ner_pipeline = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple")
19
+
20
+ def get_random_snippet(stream_iter, tokenizer, max_tokens=350, max_attempts=1000):
21
+ for _ in range(max_attempts):
22
+ try:
23
+ sample = next(stream_iter)['text']
24
+ tokens = tokenizer.tokenize(sample)
25
+ if len(tokens) <= max_tokens:
26
+ return sample
27
+ except StopIteration:
28
+ break
29
+ return "No suitable snippet found."
30
+
31
+ def extract_textile_machinery_entities(text):
32
+ ner_results = ner_pipeline(text)
33
+ textile_entities = [ent for ent in ner_results if ent['entity_group'] == 'TEXTILE_MACHINERY']
34
+ return textile_entities
35
+
36
+ def analyze_text():
37
+ snippet = get_random_snippet(dataset_iter, tokenizer)
38
+ entities = extract_textile_machinery_entities(snippet)
39
+
40
+ # Highlight entities in the text
41
+ for ent in sorted(entities, key=lambda x: x['start'], reverse=True):
42
+ snippet = snippet[:ent['start']] + f"**{snippet['start']:ent['end']}**" + snippet[ent['end']:]
43
+
44
+ return snippet, entities
45
+
46
+ # Build Gradio interface
47
+ with gr.Blocks() as demo_interface:
48
+ gr.Markdown("# Textile Machinery Entity Recognition Demo")
49
+ gr.Markdown("Click the button below to analyze a random text snippet.")
50
+ with gr.Row():
51
+ analyze_button = gr.Button("Analyze Random Snippet")
52
+ output_text = gr.Markdown()
53
+ output_entities = gr.JSON()
54
+
55
+ analyze_button.click(fn=analyze_text, outputs=[output_text, output_entities])
56
+
57
+ demo_interface.launch()