mgbam commited on
Commit
9c7387c
·
1 Parent(s): 1b899cb

Add untracked files and synchronize with remote

Browse files
Dockerfile ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10
2
+
3
+ WORKDIR /app
4
+
5
+ COPY . /app
6
+
7
+ RUN apt-get update && apt-get install -y curl
8
+ RUN pip install -r requirements.txt
9
+
10
+ RUN python -m spacy download en_core_web_sm
11
+
12
+ CMD ["python", "app.py"]
New Text Document.txt ADDED
File without changes
components/data_utils.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def partition_data(dataset, num_clients):
2
+ """
3
+ Partitions a dataset into `num_clients` subsets.
4
+ This is just a placeholder. Implement a more sophisticated partitioning strategy
5
+ (e.g., based on medical specialty, patient demographics) for a real application.
6
+ """
7
+ data_per_client = len(dataset) // num_clients
8
+ remaining_data = len(dataset) % num_clients
9
+ partitioned_data = []
10
+ start_index = 0
11
+ for i in range(num_clients):
12
+ end_index = start_index + data_per_client + (1 if i < remaining_data else 0)
13
+ partitioned_data.append(dataset[start_index:end_index])
14
+ start_index = end_index
15
+ return partitioned_data
components/federated_learning.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import flwr as fl
2
+ import torch
3
+ from collections import OrderedDict # For the example provided.
4
+
5
+ def run_federated_learning():
6
+ """
7
+ Sets up and starts a federated learning simulation.
8
+ This is a highly conceptual example. Actual implementation requires:
9
+ 1. A defined model architecture.
10
+ 2. A training loop using PyTorch or TensorFlow.
11
+ 3. Data loaders.
12
+ 4. Proper handling of FL strategies.
13
+ """
14
+
15
+ class FlowerClient(fl.client.NumPyClient):
16
+ def __init__(self, model, trainloader, valloader):
17
+ self.model = model
18
+ self.trainloader = trainloader
19
+ self.valloader = valloader
20
+
21
+ def get_parameters(self, config):
22
+ return [val.cpu().numpy() for _, val in self.model.state_dict().items()]
23
+
24
+ def set_parameters(self, parameters):
25
+ params_dict = zip(self.model.state_dict().keys(), parameters)
26
+ state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
27
+ self.model.load_state_dict(state_dict, strict=True)
28
+
29
+ def fit(self, parameters, config):
30
+ self.set_parameters(parameters)
31
+ # Train.
32
+ print("Train the parameters here.")
33
+ return parameters, 1, {}
34
+
35
+ def evaluate(self, parameters, config):
36
+ self.set_parameters(parameters)
37
+ # Test (validate).
38
+ return 1,1, {"accuracy": 1}
39
+
40
+ #Flower code
41
+ #The parameters needs to be added.
42
+
43
+ print("Started Simulation FL code")
components/knowledge_graph.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from py2neo import Graph, Node, Relationship
2
+ import spacy
3
+
4
+ def extract_knowledge_graph(text, nlp):
5
+ """Extracts entities and relationships and stores them to Neo4j."""
6
+
7
+ graph = Graph("bolt://localhost:7687", auth=("neo4j", "password")) # Adjust credentials
8
+
9
+ doc = nlp(text)
10
+
11
+ for ent in doc.ents:
12
+ node = Node("Entity", name=ent.text, label=ent.label_)
13
+ graph.create(node)
14
+
15
+ #This requires more work for the relationship
16
+ """
17
+ This needs more work to make the information work.
18
+ Example only. More data cleaning needed before real implementation
19
+
20
+ for token in doc:
21
+ # Example: look for verbs connecting entities
22
+ if token.dep_ == "ROOT" and token.pos_ == "VERB":
23
+ for child in token.children:
24
+ if child.dep_ == "nsubj" and child.ent_type_: # Subject is an entity
25
+ for obj in token.children:
26
+ if obj.dep_ == "dobj" and obj.ent_type_: # Object is an entity
27
+ subject_node = Node("Entity", name=child.text, label=child.ent_type_)
28
+ object_node = Node("Entity", name=obj.text, label=obj.ent_type_)
29
+ relation = Relationship(subject_node, token.text, object_node)
30
+ graph.create(relation)
31
+ """
32
+
33
+ print("Successfully loaded data to the knowledge base.")
34
+
35
+ # Example Node
36
+ print("Create a node called entity.")
components/model_utils.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline
2
+ import os
3
+
4
+ def load_summarization_model():
5
+ """Loads the summarization model. Check for HUGGINGFACE_API_TOKEN first."""
6
+ api_token = os.environ.get("HUGGINGFACE_API_TOKEN")
7
+ model_name = "facebook/bart-large-cnn" # Or whatever
8
+
9
+ if not api_token:
10
+ print("HUGGINGFACE_API_TOKEN not found. Summarization will not work.")
11
+ return None
12
+ try:
13
+ summarizer = pipeline("summarization", model=model_name, token=api_token)
14
+ print(f"Summarization Model {model_name} Loaded...")
15
+ return summarizer
16
+ except Exception as e:
17
+ print(f"Model load error: {e}")
18
+ return None
components/pubmed_search.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from Bio import Entrez
2
+ import os # For environment variables and file paths
3
+
4
+ # ---------------------------- Configuration ----------------------------
5
+
6
+ # ---------------------------- Helper Functions ----------------------------
7
+
8
+ def log_error(message: str):
9
+ """Logs an error message to the console and a file (if possible)."""
10
+ print(f"ERROR: {message}")
11
+ try:
12
+ with open("error_log.txt", "a") as f:
13
+ f.write(f"{message}\n")
14
+ except:
15
+ print("Couldn't write to error log file.") #If logging fails, still print to console
16
+
17
+ # ---------------------------- Tool Functions ----------------------------
18
+
19
+ def search_pubmed(query: str) -> list:
20
+ """Searches PubMed and returns a list of article IDs."""
21
+ try:
22
+ Entrez.email = os.environ.get("ENTREZ_EMAIL", "[email protected]")
23
+ handle = Entrez.esearch(db="pubmed", term=query, retmax="5")
24
+ record = Entrez.read(handle)
25
+ handle.close()
26
+ return record["IdList"]
27
+ except Exception as e:
28
+ log_error(f"PubMed search error: {e}")
29
+ return [f"Error during PubMed search: {e}"]
30
+
31
+ def fetch_abstract(article_id: str) -> str:
32
+ """Fetches the abstract for a given PubMed article ID."""
33
+ try:
34
+ Entrez.email = os.environ.get("ENTREZ_EMAIL", "[email protected]")
35
+ handle = Entrez.efetch(db="pubmed", id=article_id, rettype="abstract", retmode="text")
36
+ abstract = handle.read()
37
+ handle.close()
38
+ return abstract
39
+ except Exception as e:
40
+ log_error(f"Error fetching abstract for {article_id}: {e}")
41
+ return f"Error fetching abstract for {article_id}: {e}"
42
+
43
+ # ---------------------------- Agent Function ----------------------------
44
+
45
+ def medai_agent(query: str) -> str:
46
+ """Orchestrates the medical literature review and presents abstract."""
47
+ article_ids = search_pubmed(query)
48
+
49
+ if isinstance(article_ids, list) and article_ids:
50
+ results = []
51
+ for article_id in article_ids:
52
+ abstract = fetch_abstract(article_id)
53
+ if "Error" not in abstract:
54
+ results.append(f"<div class='article'>\n"
55
+ f" <h3 class='article-id'>Article ID: {article_id}</h3>\n"
56
+ f" <p class='abstract'><strong>Abstract:</strong> {abstract}</p>\n"
57
+ f"</div>\n")
58
+ else:
59
+ results.append(f"<div class='article error'>\n"
60
+ f" <h3 class='article-id'>Article ID: {article_id}</h3>\n"
61
+ f" <p class='error-message'>Error processing article: {abstract}</p>\n"
62
+ f"</div>\n")
63
+ return "\n".join(results)
64
+ else:
65
+ return f"No articles found or error occurred: {article_ids}"
components/vis.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import io
3
+ import base64
4
+
5
+ def generate_federated_learning_plot(client_accuracies):
6
+ """
7
+ Generates a plot showing the training accuracy of each client in a federated learning setting.
8
+ This is a placeholder. You'll need to integrate it with your actual FL framework
9
+ and store the client accuracies during training.
10
+ """
11
+ # Assuming client_accuracies is a dictionary of client_id: accuracy
12
+ client_ids = list(client_accuracies.keys())
13
+ accuracies = list(client_accuracies.values())
14
+
15
+ plt.figure(figsize=(10, 6))
16
+ plt.bar(client_ids, accuracies, color='skyblue')
17
+ plt.xlabel('Client ID')
18
+ plt.ylabel('Accuracy')
19
+ plt.title('Federated Learning: Client Accuracies')
20
+ plt.ylim(0, 1) # Assuming accuracy is between 0 and 1
21
+ plt.xticks(rotation=45, ha='right')
22
+ plt.tight_layout()
23
+
24
+ # Convert plot to base64 image
25
+ img_buf = io.BytesIO()
26
+ plt.savefig(img_buf, format='png')
27
+ img_buf.seek(0)
28
+ img_data = base64.b64encode(img_buf.read()).decode('utf-8')
29
+ plt.close() # Close the plot to free memory
30
+ return f'<img src="data:image/png;base64,{img_data}" alt="Federated Learning Plot"/>'
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio
2
+ transformers
3
+ biopython
4
+ spacy