Commit 67445cd

first decision tree try, new sentiment analysis plot

1 parent e0930df
File tree: 4 files changed (+194 additions, −3 deletions)


analysis/decisionTree.py

Lines changed: 106 additions & 0 deletions

@@ -0,0 +1,106 @@
from pathlib import Path
import sys
import string

import sklearn.linear_model
import sklearn.tree
import sklearn.dummy
from sklearn.model_selection import train_test_split
from nltk import FreqDist
from nltk.corpus import stopwords
import pandas as pd
import numpy as np
from tqdm import tqdm
import graphviz
from analysis import loadMetadata, concatTables


STOPWORDS = stopwords.words() + list(string.punctuation) + ["'s", "--", 'applause', "’re", "--ms", "--the"]


def getWordFrequencySpeaker(amount: int = 10) -> pd.DataFrame:
    metadataDF: pd.DataFrame = loadMetadata()
    totalDF = pd.DataFrame()

    root = Path('corpus') / 'tables'

    for index, row in tqdm(metadataDF.iterrows(), ncols=80, total=len(metadataDF)):
        tablePath = root / row['linkTables']

        tableDF = pd.read_csv(tablePath)

        fdist = FreqDist(cleanTokens(tableDF))
        mostCommon = fdist.most_common(amount)

        df = pd.DataFrame(data=dict(mostCommon) | {'PERIOD': row['period'], 'SPEAKER': row['speaker']}, index=[index])

        totalDF = totalDF.combine_first(df)

    totalDF.fillna(0, inplace=True)
    totalDF = totalDF.convert_dtypes()

    return totalDF


def cleanTokens(tableDF: pd.DataFrame) -> list:
    return [w.lower() for w in tableDF['LEMMA'] if w not in STOPWORDS and w.isalpha()]


def renderDecisionTree(tree, featureNames) -> None:
    # Render the fitted tree via Graphviz. Feature names are passed in
    # explicitly; class names are taken from the classifier itself so the
    # labels match the order of tree.classes_.
    sklearn.tree.export_graphviz(tree, out_file='decisiontree.dot',
                                 feature_names=featureNames,
                                 class_names=list(tree.classes_),
                                 filled=True, rounded=True,
                                 special_characters=True)

    with open('decisiontree.dot') as f:
        dot_graph = f.read()

    graph = graphviz.Source(dot_graph)
    graph.render('decisionTree', format='png', cleanup=True)

    return None


def trainTreeClassifier(df: pd.DataFrame) -> sklearn.tree.DecisionTreeClassifier:
    # Minimal training routine mirroring the __main__ block below: expects the
    # word-frequency table with a SPEAKER column as the target label.
    features = df.drop(columns=['SPEAKER', 'PERIOD'], errors='ignore')
    target = df['SPEAKER']

    treeClf = sklearn.tree.DecisionTreeClassifier(max_depth=9)
    treeClf.fit(features, target)

    return treeClf


if __name__ == '__main__':
    # totalDF = getWordFrequencySpeaker()
    # totalDF.to_parquet('wordFreq.parquet', index=False)

    df = pd.read_parquet('wordFreq.parquet')
    df = df.convert_dtypes()
    df.pop('PERIOD')

    y = df.pop('SPEAKER')
    train, test, y_train, y_test = train_test_split(df, y, test_size=0.2, random_state=42)

    treeClf = sklearn.tree.DecisionTreeClassifier(max_depth=9)

    treeClf.fit(train, y_train)

    accuracy = treeClf.score(test, y_test)
    print(f'Accuracy: {accuracy:.2f}')

    for strategy in ['most_frequent', 'prior', 'stratified', 'uniform']:
        dummyClf = sklearn.dummy.DummyClassifier(strategy=strategy)
        dummyClf.fit(train, y_train)
        accuracy = dummyClf.score(test, y_test)
        print(f'Accuracy Dummy: {accuracy:.2f}')

    # renderDecisionTree(treeClf, df.columns)
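
For orientation, a minimal usage sketch of the helpers above (not part of the commit; it assumes the module is importable as analysis.decisionTree, that the corpus tables used by loadMetadata() are in place, and that the NLTK stopword lists are available):

    # Hypothetical driver script; function names come from the file above.
    from analysis.decisionTree import (getWordFrequencySpeaker,
                                       trainTreeClassifier, renderDecisionTree)

    freqDF = getWordFrequencySpeaker(amount=10)   # top-10 lemma counts per speech plus SPEAKER/PERIOD labels
    clf = trainTreeClassifier(freqDF)             # depth-9 decision tree fitted on the word counts
    renderDecisionTree(clf, freqDF.drop(columns=['SPEAKER', 'PERIOD']).columns)  # writes decisionTree.png
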
Binary file (131 KB); image preview not rendered.

analysis/sentiments.py

Lines changed: 87 additions & 3 deletions
@@ -293,7 +293,6 @@ def sentimentViolinPeriod(show: bool = True) -> None:
     return None
 
 
-
 def sentimentTTestPeriod(show: bool = True) -> None:
     metadataDF: pd.DataFrame = loadMetadata()
 
@@ -331,6 +330,86 @@ def sentimentTTestPeriod(show: bool = True) -> None:
     return ttestResult
 
 
+def sentimentPerText(show: bool = True) -> None:
+    metadataDF: pd.DataFrame = loadMetadata()
+
+    root = Path('corpus') / 'tables'
+    n_quantiles = 5
+    resultDF = pd.DataFrame(columns=list(range(n_quantiles)) + ['speaker'])
+
+
+    layout = {'side': ['negative', 'positive'],
+              'color': ['blue', 'red']}
+    yaxisRange = (-0.5, 1)
+
+    if False:  # precomputation disabled; the aggregated results are read from test2.csv below
+        for i, row in tqdm(metadataDF.iterrows(), ncols=80, total=len(metadataDF)):
+            tablePath = root / row['linkTables']
+
+            df = pd.read_csv(tablePath)
+
+
+            df = df[['SENTENCE_ID', 'SENTIMENT_SENTENCE']].groupby('SENTENCE_ID').agg('mean')
+
+            # Add quantile groups based on the sentence index
+            df['quantile'] = pd.qcut(df.index, q=n_quantiles, labels=False)
+
+            # Calculate the mean sentiment for each quantile
+            quantiles = df.groupby('quantile')['SENTIMENT_SENTENCE'].mean()
+
+            # Turn the quantile means into a row and add the 'speaker' column
+            quantiles_row = quantiles.T
+            quantiles_row['speaker'] = row['speaker']
+
+            # Append the new row to the result DataFrame
+            resultDF.loc[len(resultDF)] = quantiles_row
+
+
+
+    # print(resultDF)
+
+    # resultDF = resultDF.groupby('speaker').mean().T
+    # resultDF.rename(index={'speaker': 'Quantiles'}, inplace=True)
+    # resultDF.reset_index(inplace=True)
+    # resultDF.to_csv('test2.csv')
+
+
+    resultDF = pd.read_csv('test2.csv')
+    print(resultDF)
+    print(resultDF.index, resultDF.columns)
+
+    fig = go.Figure()
+
+    # Add a line trace for each column in resultDF except 'Quantiles'
+    for column in resultDF.columns[1:]:
+        fig.add_trace(go.Scatter(
+            x=resultDF['Quantiles'] + 1,
+            y=resultDF[column],
+            mode='lines',
+            name=column
+        ))
+
+    # Update layout
+    fig.update_layout(
+        title='Mean Sentiment per Speaker in Quantiles',
+        xaxis_title='Quantiles',
+        yaxis_title='Mean Sentiment',
+        legend=dict(
+            orientation='h',   # horizontal legend
+            x=0.5,             # centered horizontally
+            y=1.05,            # placed just above the plot area
+            xanchor='center',  # anchor the legend horizontally at its center
+            yanchor='top'      # anchor the legend vertically at its top edge
+        )
+    )
+
+    fig.show(height=800, width=1200)
+    fig.write_image(Path('analysis/sentimentAnalysis/sentimentAnalysisQuantiles.png'),
+                    height=800, width=1200)
+
+    return
+
+
 def formatStatistics(ttest:scipy.stats._stats_py.TtestResult, cohensD_:float) -> str:
     pValue = ttest.pvalue
 
@@ -367,6 +446,11 @@ def cohensD(series1:pd.Series, series2:pd.Series) -> float:
 
 # sentimentBoxplotSpeaker()
 # sentimentBoxplotYear()
-sentimentViolinPeriod(show=False)
+# sentimentViolinPeriod(show=False)
+sentimentPerText(show=True)
+
+# ttestResult = sentimentTTestPeriod(show=True)
+
+
+
 
-# ttestResult = sentimentTTestPeriod(show=True)
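
The new sentimentPerText() splits each speech's sentences into five position-based bins with pd.qcut and averages sentiment per bin. A small self-contained sketch of that bucketing pattern, using made-up scores rather than corpus data:

    import pandas as pd

    # Toy per-sentence sentiment scores for one speech, indexed by sentence id.
    sent = pd.DataFrame({'SENTIMENT_SENTENCE': [0.2, 0.1, -0.3, 0.0, 0.4,
                                                0.5, -0.1, 0.3, 0.6, 0.2]})

    # Bin the sentence positions into 5 equal-sized groups (0..4), as
    # sentimentPerText does with df.index, then average sentiment per group.
    sent['quantile'] = pd.qcut(sent.index, q=5, labels=False)
    profile = sent.groupby('quantile')['SENTIMENT_SENTENCE'].mean()

    print(profile)   # one mean sentiment value per fifth of the speech
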

analysis/wordFrequency.py

Lines changed: 1 addition & 0 deletions

@@ -155,6 +155,7 @@ def plotWordFrequencySpeaker(amount: int = 10, show: bool = False) -> None:
         posTbl = tokenTbl[tokenTbl['POS'].isin(tags)]
         fdist = FreqDist(posTbl['LEMMA'].str.lower())
         mostCommon = fdist.most_common(amount)
+
 
         fig.add_trace(go.Bar(
             x=[w[0] for w in mostCommon],
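
For reference, nltk's FreqDist counts token occurrences and most_common(n) returns (token, count) pairs in descending order of frequency; a tiny illustration with made-up lemmas:

    from nltk import FreqDist

    lemmas = ['economy', 'people', 'economy', 'nation', 'people', 'economy']
    fdist = FreqDist(lemmas)
    print(fdist.most_common(2))   # [('economy', 3), ('people', 2)]
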
