Add untracked files and synchronize with remote

- Dockerfile +12 -0
- New Text Document.txt +0 -0
- components/data_utils.py +15 -0
- components/federated_learning.py +43 -0
- components/knowledge_graph.py +36 -0
- components/model_utils.py +18 -0
- components/pubmed_search.py +65 -0
- components/vis.py +30 -0
- requirements.txt +4 -0
Dockerfile
ADDED
@@ -0,0 +1,12 @@
+FROM python:3.10
+
+WORKDIR /app
+
+COPY . /app
+
+RUN apt-get update && apt-get install -y curl
+RUN pip install -r requirements.txt
+
+RUN python -m spacy download en_core_web_sm
+
+CMD ["python", "app.py"]
New Text Document.txt
ADDED
File without changes
components/data_utils.py
ADDED
@@ -0,0 +1,15 @@
+def partition_data(dataset, num_clients):
+    """
+    Partitions a dataset into `num_clients` subsets.
+    This is just a placeholder. Implement a more sophisticated partitioning strategy
+    (e.g., based on medical specialty, patient demographics) for a real application.
+    """
+    data_per_client = len(dataset) // num_clients
+    remaining_data = len(dataset) % num_clients
+    partitioned_data = []
+    start_index = 0
+    for i in range(num_clients):
+        end_index = start_index + data_per_client + (1 if i < remaining_data else 0)
+        partitioned_data.append(dataset[start_index:end_index])
+        start_index = end_index
+    return partitioned_data
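Usage note: a minimal sketch of how partition_data splits a toy dataset; the list of 10 records and 3 clients here are illustrative only.

from components.data_utils import partition_data

records = list(range(10))  # toy dataset
parts = partition_data(records, num_clients=3)
print(parts)  # [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]
# The remainder (10 % 3 = 1) is absorbed by the first client.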
components/federated_learning.py
ADDED
@@ -0,0 +1,43 @@
+import flwr as fl
+import torch
+from collections import OrderedDict
+
+def run_federated_learning():
+    """
+    Sets up and starts a federated learning simulation.
+    This is a highly conceptual example. An actual implementation requires:
+    1. A defined model architecture.
+    2. A training loop using PyTorch or TensorFlow.
+    3. Data loaders.
+    4. Proper handling of FL strategies.
+    """
+
+    class FlowerClient(fl.client.NumPyClient):
+        def __init__(self, model, trainloader, valloader):
+            self.model = model
+            self.trainloader = trainloader
+            self.valloader = valloader
+
+        def get_parameters(self, config):
+            return [val.cpu().numpy() for _, val in self.model.state_dict().items()]
+
+        def set_parameters(self, parameters):
+            params_dict = zip(self.model.state_dict().keys(), parameters)
+            state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
+            self.model.load_state_dict(state_dict, strict=True)
+
+        def fit(self, parameters, config):
+            self.set_parameters(parameters)
+            # Placeholder: train self.model on self.trainloader here.
+            print("Train the parameters here.")
+            return self.get_parameters(config), 1, {}
+
+        def evaluate(self, parameters, config):
+            self.set_parameters(parameters)
+            # Placeholder: evaluate self.model on self.valloader here.
+            return 1.0, 1, {"accuracy": 1.0}
+
+    # Flower server/simulation code goes here.
+    # The model, data loaders, and strategy parameters still need to be added.
+
+    print("Started federated learning simulation (placeholder).")
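Usage note: the file stops short of wiring FlowerClient into a simulation. A minimal sketch of that wiring under the Flower 1.x API, assuming FlowerClient is lifted to module scope and that make_model() and get_loaders(cid) are hypothetical helpers this repo does not yet define (newer Flower releases expect client_fn to return FlowerClient(...).to_client()):

import flwr as fl

def client_fn(cid: str):
    model = make_model()                       # hypothetical model factory
    trainloader, valloader = get_loaders(cid)  # hypothetical per-client loaders
    return FlowerClient(model, trainloader, valloader)

fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=3,
    config=fl.server.ServerConfig(num_rounds=3),
    strategy=fl.server.strategy.FedAvg(),  # plain federated averaging
)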
components/knowledge_graph.py
ADDED
@@ -0,0 +1,36 @@
+from py2neo import Graph, Node, Relationship
+import spacy
+
+def extract_knowledge_graph(text, nlp):
+    """Extracts entities and relationships and stores them in Neo4j."""
+
+    graph = Graph("bolt://localhost:7687", auth=("neo4j", "password"))  # Adjust credentials
+
+    doc = nlp(text)
+
+    for ent in doc.ents:
+        node = Node("Entity", name=ent.text, label=ent.label_)
+        graph.create(node)
+
+    # Relationship extraction still requires more work.
+    """
+    Example only. The sketch below needs more data cleaning before a real
+    implementation.
+
+    for token in doc:
+        # Example: look for verbs connecting entities
+        if token.dep_ == "ROOT" and token.pos_ == "VERB":
+            for child in token.children:
+                if child.dep_ == "nsubj" and child.ent_type_:  # Subject is an entity
+                    for obj in token.children:
+                        if obj.dep_ == "dobj" and obj.ent_type_:  # Object is an entity
+                            subject_node = Node("Entity", name=child.text, label=child.ent_type_)
+                            object_node = Node("Entity", name=obj.text, label=obj.ent_type_)
+                            relation = Relationship(subject_node, token.text, object_node)
+                            graph.create(relation)
+    """
+
+    print("Successfully loaded data into the knowledge graph.")
+
+    # Example node: each recognized entity becomes an "Entity" node.
+    print("Created Entity nodes.")
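Usage note: a sketch of calling the extractor, assuming a local Neo4j instance reachable at bolt://localhost:7687 with the credentials hard-coded above, and the en_core_web_sm model the Dockerfile downloads:

import spacy
from components.knowledge_graph import extract_knowledge_graph

nlp = spacy.load("en_core_web_sm")
extract_knowledge_graph("Aspirin reduces fever in adult patients.", nlp)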
components/model_utils.py
ADDED
@@ -0,0 +1,18 @@
+from transformers import pipeline
+import os
+
+def load_summarization_model():
+    """Loads the summarization model. Checks for HUGGINGFACE_API_TOKEN first."""
+    api_token = os.environ.get("HUGGINGFACE_API_TOKEN")
+    model_name = "facebook/bart-large-cnn"  # Or another summarization model
+
+    if not api_token:
+        print("HUGGINGFACE_API_TOKEN not found. Summarization will not work.")
+        return None
+    try:
+        summarizer = pipeline("summarization", model=model_name, token=api_token)
+        print(f"Summarization model {model_name} loaded.")
+        return summarizer
+    except Exception as e:
+        print(f"Model load error: {e}")
+        return None
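Usage note: a quick check of the loader, assuming HUGGINGFACE_API_TOKEN is set in the environment; the summarization pipeline returns a list of dicts with a summary_text key:

from components.model_utils import load_summarization_model

summarizer = load_summarization_model()
if summarizer is not None:
    out = summarizer("Replace with a long clinical abstract ...",
                     max_length=60, min_length=20, do_sample=False)
    print(out[0]["summary_text"])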
components/pubmed_search.py
ADDED
@@ -0,0 +1,65 @@
+from Bio import Entrez
+import os  # For environment variables and file paths
+
+# ---------------------------- Configuration ----------------------------
+
+# ---------------------------- Helper Functions ----------------------------
+
+def log_error(message: str):
+    """Logs an error message to the console and a file (if possible)."""
+    print(f"ERROR: {message}")
+    try:
+        with open("error_log.txt", "a") as f:
+            f.write(f"{message}\n")
+    except OSError:
+        print("Couldn't write to error log file.")  # If logging fails, still print to console
+
+# ---------------------------- Tool Functions ----------------------------
+
+def search_pubmed(query: str) -> list:
+    """Searches PubMed and returns a list of article IDs."""
+    try:
+        Entrez.email = os.environ.get("ENTREZ_EMAIL", "[email protected]")
+        handle = Entrez.esearch(db="pubmed", term=query, retmax="5")
+        record = Entrez.read(handle)
+        handle.close()
+        return record["IdList"]
+    except Exception as e:
+        log_error(f"PubMed search error: {e}")
+        return []  # Return an empty list so callers don't mistake error text for IDs
+
+def fetch_abstract(article_id: str) -> str:
+    """Fetches the abstract for a given PubMed article ID."""
+    try:
+        Entrez.email = os.environ.get("ENTREZ_EMAIL", "[email protected]")
+        handle = Entrez.efetch(db="pubmed", id=article_id, rettype="abstract", retmode="text")
+        abstract = handle.read()
+        handle.close()
+        return abstract
+    except Exception as e:
+        log_error(f"Error fetching abstract for {article_id}: {e}")
+        return f"Error fetching abstract for {article_id}: {e}"
+
+# ---------------------------- Agent Function ----------------------------
+
+def medai_agent(query: str) -> str:
+    """Orchestrates the medical literature review and presents abstracts."""
+    article_ids = search_pubmed(query)
+
+    if article_ids:
+        results = []
+        for article_id in article_ids:
+            abstract = fetch_abstract(article_id)
+            if "Error" not in abstract:
+                results.append(f"<div class='article'>\n"
+                               f"  <h3 class='article-id'>Article ID: {article_id}</h3>\n"
+                               f"  <p class='abstract'><strong>Abstract:</strong> {abstract}</p>\n"
+                               f"</div>\n")
+            else:
+                results.append(f"<div class='article error'>\n"
+                               f"  <h3 class='article-id'>Article ID: {article_id}</h3>\n"
+                               f"  <p class='error-message'>Error processing article: {abstract}</p>\n"
+                               f"</div>\n")
+        return "\n".join(results)
+    else:
+        return "No articles found or an error occurred while searching PubMed."
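Usage note: end-to-end use is a single call; the return value is an HTML fragment of abstracts. ENTREZ_EMAIL should be set so NCBI can contact you about your queries (you@example.com below is a placeholder):

import os
from components.pubmed_search import medai_agent

os.environ["ENTREZ_EMAIL"] = "you@example.com"  # placeholder; use your real address
print(medai_agent("sepsis biomarkers"))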
components/vis.py
ADDED
@@ -0,0 +1,30 @@
+import matplotlib.pyplot as plt
+import io
+import base64
+
+def generate_federated_learning_plot(client_accuracies):
+    """
+    Generates a plot showing the training accuracy of each client in a federated learning setting.
+    This is a placeholder. You'll need to integrate it with your actual FL framework
+    and store the client accuracies during training.
+    """
+    # Assuming client_accuracies is a dictionary of client_id: accuracy
+    client_ids = list(client_accuracies.keys())
+    accuracies = list(client_accuracies.values())
+
+    plt.figure(figsize=(10, 6))
+    plt.bar(client_ids, accuracies, color='skyblue')
+    plt.xlabel('Client ID')
+    plt.ylabel('Accuracy')
+    plt.title('Federated Learning: Client Accuracies')
+    plt.ylim(0, 1)  # Assuming accuracy is between 0 and 1
+    plt.xticks(rotation=45, ha='right')
+    plt.tight_layout()
+
+    # Convert plot to base64 image
+    img_buf = io.BytesIO()
+    plt.savefig(img_buf, format='png')
+    img_buf.seek(0)
+    img_data = base64.b64encode(img_buf.read()).decode('utf-8')
+    plt.close()  # Close the plot to free memory
+    return f'<img src="data:image/png;base64,{img_data}" alt="Federated Learning Plot"/>'
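Usage note: a sketch of feeding the helper, assuming per-client accuracies have been collected from an FL run; the returned <img> tag can be embedded directly in any HTML-capable UI component:

from components.vis import generate_federated_learning_plot

accuracies = {"client_0": 0.82, "client_1": 0.75, "client_2": 0.91}  # example values
html_img = generate_federated_learning_plot(accuracies)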
requirements.txt
ADDED
@@ -0,0 +1,4 @@
+gradio
+transformers
+biopython
+spacy
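Note: the components above also import flwr, torch, py2neo, and matplotlib, none of which appear in this requirements.txt. An extended list (an assumed fix, not part of this commit) would be:

gradio
transformers
biopython
spacy
flwr
torch
py2neo
matplotlib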